diff --git a/docs/README_ETP.md b/docs/README_ETP.md new file mode 100644 index 0000000000..52368579fb --- /dev/null +++ b/docs/README_ETP.md @@ -0,0 +1,784 @@ +# Extended Tensor Parallelism (ETP) + +## Overview + +Extended Tensor Parallelism (ETP) is a **light-weight**, **high-performance** and **memory-efficient** distributed training strategy implemented in TransformerEngine. It shards weight tensors across an ETP process group and reconstructs them on-demand via async all-gather, enabling training of larger models without sacrificing throughput by overlapping communication with computation. + +ETP applies to any TE module that wraps a `Linear` layer: `Linear`, `LayerNormLinear`, `LayerNormMLP` (for dense models), and `GroupedLinear` (for MoE models). When used with `GroupedLinear`, ETP provides additional batched coalesced all-gather support for gathering multiple expert weights in a single NCCL operation. + +ETP supports all TE low-precision formats (FP8, MXFP8, NVFP4) with a **quantize-then-gather** strategy: each rank quantizes only its local shard before the all-gather, so wire bandwidth scales with the quantized size (0.5× for FP8, 0.25× for NVFP4) rather than the full BF16 weight. + +--- + +## Performance + +TODO(shiqingf): add performance for Ultra model in nvfp4. + +---- + +## Features + +### User-Visible Features + +| Feature | Description | +|---|---| +| **Weight sharding** | Weights sharded 1/N across ETP group along `out_features`, reducing per-GPU VRAM | +| **Async prefetch** | Next layer's weight all-gather overlaps with current layer's GEMM in both forward (prefetches `next_w`) and backward (prefetches `prev_w`); controlled by `ETPConfig.weight_prefetch` | +| **NVFP4 support** | Full 4-bit quantized all-gather with interleaved-format post-processing | +| **FP8 / MXFP8 support** | Quantized shards with ETP-group amax reduction | +| **Routed expert support** | Batched coalesced all-gather for all experts in a MoE layer (GroupedLinear) | +| **Composable with TP/SP** | Orthogonal to tensor parallelism and sequence parallelism | +| **CUDA Graphs compatible** | Dense-chain prefetches captured in graphs; expert-chain runs eagerly. DDP RS serialized via `register_grad_accum_hook` (called from `_finalize_wgrad` for eager params, from `_CudagraphReplayNode.backward` for graphed params). Forward drains at CG/eager boundary prevent IB races. | +| **Debug naming** | `tag_etp_params_with_names(model)` populates human-readable names on every `ETPShardedParam`; the prefetch-link table is printed atomically at the start of the second forward pass | + +### Implementation Mechanisms + +| Mechanism | Description | +|---|---| +| **Alignment padding** | Shards padded to `ETPConfig.pad_for_alignment × etp_size` rows at construction via `get_padded_shard()`; only last rank carries padding (`is_padded_last_rank`); padding stripped in `_strip_padding()` both post-gather (before GEMM) and post-reduce-scatter (before wgrad accumulation) | +| **Fine-grained weight scheduling** | Each weight has its own `ETPWeightState` lifecycle and is scheduled independently via a doubly-linked list (`next_w`/`prev_w`), enabling per-weight AG/RS overlap at single-weight granularity. Two independent chains are maintained: one for dense params (mamba/attn/shared_expert) and one for expert params (grouped_fc1/grouped_fc2) | +| **Separate AG and RS state** | All-gather state (`state`) and reduce-scatter state (`rs_state`) are tracked independently per param, allowing forward and backward async ops to proceed without interference | +| **Shared CUDA streams** | AG and RS run on shared CUDA streams (`get_ag_stream()`, `get_rs_stream()`) across all chains; completion is signaled back via per-param CUDA events (`ag_event`, `rs_event`) that the compute stream waits on before consuming the result. Streams must be shared because `ag_event` is recorded on the AG stream during CUDA graph capture; using a different stream at replay would cause `ag_event.wait()` to see a stale recording | +| **Ticket-based buffer cache** | `ETPWeightCache` assigns persistent tickets via `reserve()`; buffers are lazily allocated on `get()` and returned to the pool on `release()`; `clear()` drops all buffers while keeping tickets valid for lazy re-allocation (used for CUDA Graph re-capture) | +| **Wgrad reduce-scatter** | Async reduce-scatter of weight gradients, deferred to overlap with next layer's wgrad RS; `_finalize_wgrad()` resets `rs_state`, strips padding, and accumulates the result into `param.main_grad`, returning a dummy-zero grad to autograd | + +--- + +## Design + +### Core Idea + +In standard Tensor Parallelism (TP), each GPU holds a shard of each weight and communicates activations. ETP goes one level deeper: **each weight is sharded along the `out_features` dimension (dim 0) across an ETP group of N GPUs**, so each GPU stores only 1/N of the weight. Before each GEMM, an all-gather reconstructs the full weight; after the backward GEMM, a reduce-scatter propagates the weight gradient back to the shards. + +``` +Standard column-parallel TP (TP=2, 2 GPUs, weight W of shape [K, M]): + GPU0 owns W[:K/2, :] (first half of out_features) + GPU1 owns W[K/2:, :] (second half of out_features) + +ETP (on top of column-parallel TP, ETP=2 per TP rank, 4 GPUs): + GPU0 (TP0, ETP0) owns W[:K/4, :] (first quarter of out_features) + GPU1 (TP0, ETP1) owns W[K/4:K/2, :] (second quarter of out_features) + GPU2 (TP1, ETP0) owns W[K/2:3K/4, :] + GPU3 (TP1, ETP1) owns W[3K/4:K, :] +``` + +ETP always shards along `out_features` regardless of the TP parallel mode (`column` or `row`). For `row` parallel mode, TP shards `in_features` while ETP shards `out_features`, making the two dimensions orthogonal. + +ETP is composable with TP and Sequence Parallelism for `Linear`, `LayerNormLinear`, and `LayerNormMLP`. The `etp_group` process group is orthogonal to the `tp_group`, giving a 2D parallelism grid. + + +### Weight Sharding + +#### Initialization + +Every rank independently allocates and initializes the **full** weight tensor, then slices out its local portion — there is no broadcast or communication during construction. + +``` +te.Linear.__init__(out_features=F, in_features=K, etp_group=group) +│ +├─ 1. Every rank: weight_tensor = torch.empty(F, K) ← full weight, same shape on all ranks +│ +├─ 2. reset_parameters() ← Kaiming-uniform init on every rank +│ identical seed ⇒ identical values on all ranks; slice is consistent without any comm +│ +└─ 3. wrap_module_params_etp(self, weight_names, etp_group) + │ + ├─ alignment = pad_for_alignment(16) × etp_size + │ pad_length = (alignment − F % alignment) % alignment + │ shard_size = (F + pad_length) // etp_size + │ + ├─ start = rank × shard_size + │ end = min((rank+1) × shard_size, F) ← clips real rows for last rank + │ shard = weight_tensor[start : end].clone() + │ + ├─ ETPShardedParam(shard) + │ .pad_length = pad_length + │ .is_padded_last_rank = (rank == etp_size−1 and pad_length > 0) + │ .group = etp_group + │ + ├─ module._parameters["weight"] = etp_shard ← replace nn.Parameter + │ + └─ del weight_tensor ← full buffer freed +``` + +Example: `F=63, K=32, etp_size=4, pad_for_alignment=16` + +``` +alignment=64, pad_length=1, shard_size=16 + +rank 0: rows [ 0:16] → ETPShardedParam [16, 32] pad_length=0 is_padded=False +rank 1: rows [16:32] → ETPShardedParam [16, 32] pad_length=0 is_padded=False +rank 2: rows [32:48] → ETPShardedParam [16, 32] pad_length=0 is_padded=False +rank 3: rows [48:63] → ETPShardedParam [15, 32] pad_length=1 is_padded=True +``` + +#### Padding and strip flow + +Padding is added **entering** each collective so all ranks contribute equal-sized chunks; it is stripped **exiting** each collective so downstream consumers see the real shape. + +``` +FORWARD + local shard [real_rows, K] (e.g. [15, 32] on last rank) + └─ get_padded_shard() → [shard_size, K] (e.g. [16, 32] zero row appended) + └─ all-gather → [padded_F, K] (e.g. [64, 32] across etp_size ranks) + └─ _strip_padding → [F, K] (e.g. [63, 32] ← weight seen by GEMM) + └─ GEMM → output [B, F] + +BACKWARD (wgrad path) + wgrad [B, F] (computed against stripped weight, so first dim is F not padded_F) + └─ _reduce_scatter pads: [F, K] → [padded_F, K] (re-pads before RS so chunks are equal) + └─ reduce-scatter → [shard_size, K] per rank + └─ _finalize_wgrad → _strip_padding → [real_rows, K] + └─ accumulated into param.main_grad (matches local shard shape) + └─ dummy zero grad returned to autograd +``` + +#### Wrapping call + +```python +# Called in Linear/LayerNormLinear/LayerNormMLP/GroupedLinear __init__ +if etp_group is not None: + wrap_module_params_etp(self, self.weight_names, etp_group) + del weight_tensor # free the temporary full-weight buffer +``` + +For `GroupedLinear` (MoE), `wrap_module_params_etp` is called with `is_grouped=True`, which additionally sets `weight_list` on the first expert's `ETPShardedParam` so all experts' weights can be batched together in a single coalesced all-gather. It also sets `chain_id='expert'` so expert params join the expert prefetch chain (separate from the dense chain). + +### State Machine + +Each `ETPShardedParam` tracks two independent state machines: one for the all-gather (`state`) and one for the reduce-scatter (`rs_state`). Each uses the same four-state enum: + +``` +NONE ──────────► ASYNC_WAIT ──────────► DATA_READY ──────────► NONE +(shard only) (AG/RS launched) (AG/RS complete, (consumed, + result in cache) back to shard) + +NONE ─────────────────────────────────► DATA_READY_SYNC ──────► NONE + (sync gather, (consumed) + result available) +``` + +The `DATA_READY_SYNC` state is used for on-demand synchronous gathers (cold start or when prefetch is disabled). `DATA_READY` is used after an async gather completes via `handle.wait()`. + +Transition validation is implemented but currently commented out in `_set_state()` / `_set_rs_state()` (guarded by `ETP_CONFIG.check_param_states`); both methods unconditionally set the new state in the current implementation. + +### Class Diagram + +
+Click to expand + +```mermaid +classDiagram + + %% ── Enums ──────────────────────────────────────────────────────────────── + class ETPWeightState { + <> + NONE + ASYNC_WAIT + DATA_READY + DATA_READY_SYNC + } + + %% ── Config ─────────────────────────────────────────────────────────────── + class ETPConfig { + <> + +int pad_for_alignment + +bool check_param_states + +bool weight_prefetch + } + + %% ── Core parameter class ───────────────────────────────────────────────── + class ETPShardedParam { + <> + $ _pending_rs_weight : ETPShardedParam + $ _first_weight_flag : bool + $ _chain_state : Dict[str, dict] + +ETPWeightState state + +ETPWeightState rs_state + +int _ag_ticket_fwd + +int _ag_ticket_bwd + +int _rs_ticket + +Event ag_event + +Event rs_event + +ETPShardHandle _prefetch_handle + +ETPShardHandle _wgrad_rs_handle + +Quantizer _quantizer + +bool did_cast_to_low_precision + +QuantizedTensor quantized + +int pad_length + +bool is_padded_last_rank + +bool prefetch_initialized + +ETPShardedParam next_w + +ETPShardedParam prev_w + +str chain_id + +bool is_routed_expert + +int expert_idx + +ProcessGroup group + +List weight_list + +Tensor wgrad_rs + +str _debug_name + +setup(weight_quantizer) + +_weights() List + +_set_state(new_state) + +_set_rs_state(new_state) + +_get_cache_key(dtype, fwd, reduce_scatter) tuple + +_unsharded_shape_padded() tuple + +_unsharded_shape() tuple + +_sharded_padded_shape() tuple + +get_padded_shard() Tensor + +_strip_padding(tensor) Tensor + +_quantize_if_needed(skip, flag) + +_all_gather_weight(async_op, ...) tuple + +_all_gather_weight_on_demand(fwd, ...) Tensor + +_get_prefetched_weight(fwd, ...) Tensor + +_wait_param_gather() + +_wait_reduce_scatter() + +all_gather_and_prefetch(fwd, ...) Tensor + +all_gather_and_prefetch_bwd() Tensor + +get_wgrad_tensor() Tensor + +_finalize_wgrad(param, wgrad_rs) [staticmethod] + +_reduce_scatter(wgrads, async_op) tuple + +wgrad_reduce_scatter(wgrad) + } + + %% ── Async all-gather handles ───────────────────────────────────────────── + class ETPShardHandle { + +Work handle + +List etp_shards + +bool reduce_scatter + +wait() + } + + class BatchedNVFP4AllGatherAsyncHandle { + <> + +List output_handles + +Work outer_async_handle + +bool _synchronized + +wait() + } + + class _NVFP4AllGatherAsyncHandle { + +NVFP4TensorStorage output + +Tensor columnwise_data_interleaved + +Tensor columnwise_scale_inv_interleaved + +int world_size + +Work async_handle + +bool _synchronized + +post_process_nvfp4_gather() + +wait() + } + + %% ── Buffer pool / ticket cache ─────────────────────────────────────────── + class _TicketSlot { + <> + +tuple key + +ETPShardedParam param + +dtype + +str chain_id + +bool reduce_scatter + +bool fwd + +Tensor buf + } + + class ETPWeightCache { + -Dict _pool + -Dict _slots + -int _next_ticket + -int _total_bytes + +reserve(param, dtype, fwd, reduce_scatter) int + +get(ticket) Tensor + +release(ticket) + +clear() + +reallocate_to_mempool(device, mempool) + -_allocate_buffer(param, dtype, reduce_scatter, fwd) Tensor + -_buf_bytes(shape, dtype) int + } + + %% ── External bases (simplified) ────────────────────────────────────────── + class torch_nn_Parameter { + <> + } + class QuantizedTensor { + <> + } + class NVFP4TensorStorage { + <> + } + + %% ── Relationships ──────────────────────────────────────────────────────── + + %% inheritance + torch_nn_Parameter <|-- ETPShardedParam + + %% state machine ownership + ETPShardedParam --> ETPWeightState : state / rs_state + + %% doubly-linked prefetch list (self-referential) + ETPShardedParam --> ETPShardedParam : next_w / prev_w + + %% grouped expert list (self-referential) + ETPShardedParam --> ETPShardedParam : weight_list + + %% in-flight prefetch / RS handles + ETPShardedParam --> ETPShardHandle : _prefetch_handle / _wgrad_rs_handle + + %% handle back-reference to shards for state transitions + ETPShardHandle --> ETPShardedParam : etp_shards + + %% handle polymorphism: plain Work or NVFP4-batched + ETPShardHandle --> BatchedNVFP4AllGatherAsyncHandle : handle (NVFP4 path) + + %% batched handle contains one entry per expert + BatchedNVFP4AllGatherAsyncHandle --> _NVFP4AllGatherAsyncHandle : output_handles + + %% config singleton controls all params + ETPShardedParam ..> ETPConfig : ETP_CONFIG + + %% buffer pool used via global singleton + ETPShardedParam ..> ETPWeightCache : reserve / get / release + + %% ticket slots + ETPWeightCache --> _TicketSlot : _slots + + %% quantized tensor stored per param + ETPShardedParam --> QuantizedTensor : quantized + + %% NVFP4 storage is a QuantizedTensor + NVFP4TensorStorage --|> QuantizedTensor + + %% NVFP4 handle output type + _NVFP4AllGatherAsyncHandle --> NVFP4TensorStorage : output +``` + +
+ +--- + +## Difference with FSDP + +FSDP (Fully Sharded Data Parallelism) and ETP both shard weight parameters, but they target different axes and serve different purposes: + +| Dimension | FSDP | ETP | +|---|---|---| +| **Sharding axis** | Data-parallel replicas | ETP process group (model-parallel dimension) | +| **Target layer** | All parameters uniformly | Any TE Linear, LayerNormLinear, LayerNormMLP, or GroupedLinear weight | +| **Communication** | All-gather before fwd, reduce-scatter after bwd | Same pattern, but orthogonal group | +| **State tracked** | PyTorch handles lifecycle | `ETPWeightState` state machine per param (separate for AG and RS) | +| **Quantization** | Framework-level, post-gather | **Quantize-then-gather** (lower bandwidth) | +| **Buffer management** | PyTorch flat-param storage | Ticket-based buffer pool per shape/dtype | +| **Prefetching** | PyTorch forward-hook prefetch | Lazy linked-list async prefetch across layers | +| **Gradient flow** | Reduce over data-parallel dim | Reduce-scatter over ETP dim | +| **Composability** | Wraps module hierarchy | Opt-in per-module via `etp_group` arg | + +**Key distinction**: FSDP shards across the *data-parallel dimension* (replicas processing different samples), while ETP shards across the *model-parallel dimension* (GPUs processing the same sample). They can coexist: a model can use FSDP for data parallelism and ETP for weight memory reduction simultaneously. + +A further practical difference is that ETP is **quantization-aware**: shards are quantized *before* the all-gather, so the wire bandwidth is proportional to the quantized size (e.g., FP4 = 1/4 of BF16), not the original weight size. FSDP gathers in full precision by default. + +--- + +## Two-Chain Architecture (Dense + Expert) + +ETP maintains **two independent prefetch chains** to cleanly separate dense and expert weight management: + +| Chain | Params | NCCL Group | CUDA Graph | +|-------|--------|-----------|------------| +| **Dense** (`chain_id='dense'`) | mamba, attention, shared expert | `PARAMETER_SHARDING_GROUP` | Captured in graphs | +| **Expert** (`chain_id='expert'`) | grouped_fc1, grouped_fc2 | `EXPERT_PARAMETER_SHARDING_GROUP` | Runs eagerly | + +Both chains share the same `ag_stream` / `rs_stream` (see "Shared Streams" below). + +### Why Two Chains Instead of One? + +The original design used a **single global chain** linking all ETP params: + +``` +Single chain (old): +CG(mamba.fc1 -> mamba.fc2) -> CG(shared_expert.fc1 -> shared_expert.fc2) -> EAGER(grouped_fc1 -> grouped_fc2) -> CG(next_mamba.fc1 -> ...) -> ... + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + crosses CG/eager boundary +``` + +This caused two problems: + +1. **Cross-chain prefetch crossing CG/eager boundary**: The single linked list linked dense params (captured in CUDA graphs) to expert params (running eagerly). The prefetch chain crossed the CG/eager boundary, causing the captured AG event sequence to include expert weight prefetches. At 64+ GPU IB scale, this interaction corrupted NCCL communicator progress tracking across graph replays and caused deadlocks. + +2. **Complex fencing**: Numerous `_drain_etp_side_streams()` fences were needed at every CG/eager boundary (forward expert compute entry, backward dispatch/combine, finalize_model_grads). These fences were fragile, hard to reason about, and didn't fully solve the 64-GPU hang. + +The two-chain design eliminates both problems: + +``` +Dense chain: CG(mamba.fc1 -> mamba.fc2) -> CG(shared_expert.fc1 -> shared_expert.fc2) -> CG(next_mamba.fc1 -> ...) -> ... +Expert chain: EAGER(grouped_fc1_L1 -> grouped_fc2_L1) -> EAGER(grouped_fc1_L2 -> grouped_fc2_L2) -> ... + (never crosses into CG, never uses PARAMETER_SHARDING_GROUP) +``` + +Each chain uses its own NCCL communicator and stays entirely within one execution mode (CG or eager). + +### Chain Construction + +Each chain builds its own doubly-linked list independently via per-chain state in `_chain_state`: + +``` +Dense chain: mamba.fc1 -> mamba.fc2 -> shared_expert.fc1 -> shared_expert.fc2 -> next_mamba.fc1 -> ... +Expert chain: grouped_fc1_layer1 -> grouped_fc2_layer1 -> grouped_fc1_layer2 -> ... +``` + +The `chain_id` is set automatically: `wrap_module_params_etp(..., is_grouped=True)` sets `chain_id='expert'`; all other params default to `chain_id='dense'`. + +### Shared Streams + +Both chains share the same `ag_stream` and `rs_stream`. Per-chain streams were considered but cause correctness issues: the `ag_event` CUDA event object is recorded on `ag_stream` during CUDA graph capture. If expert params used a different stream at replay time, `ag_event.wait()` would see a stale recording, producing Inf gradients. Shared streams avoid this while the chain-level isolation (no cross-chain `next_w`/`prev_w` links) provides the key benefit of preventing prefetch chains from crossing the CG/eager boundary. + +### Buffer Cache + +The single global `ETPWeightCache` serves both chains. Cache keys already include `expert_idx`, so dense and expert buffers never collide. `reallocate_to_mempool()` only migrates **dense-chain** buffers into the CUDA graph memory pool; expert-chain buffers remain in regular allocator memory. + +### Excluding Params from the Chain + +Setting `weight.prefetch_initialized = True` at construction skips chain registration entirely. Megatron uses this for the embedding and output-layer weights, which perform synchronous all-gathers and must not join the dense chain (they execute outside the CUDA graph boundary, and linking them into the dense chain would cause the chain to cross the CG/eager boundary, reproducing the same NCCL deadlock as the old single-chain design). Setting `_need_weight_prefetch = False` in addition disables the async path so these weights always do synchronous AG. + +### ETP + DDP Serialization (`register_grad_accum_hook`) + +ETP bypasses autograd's normal gradient flow: `wgrad_reduce_scatter` returns `None` for async RS (chain interior params), and `_finalize_wgrad` accumulates directly into `main_grad`. As a result, autograd's grad accumulator never fires for these params, and standard DDP backward hooks (`grad_acc.register_hook`) would never trigger. + +Without proper serialization, DDP reduce-scatter (IB) and ETP reduce-scatter (IB) can run concurrently on different CUDA streams at 64+ GPU IB scale, causing NCCL deadlock. + +The solution: `register_grad_accum_hook(grad_acc, hook)` stores the DDP hook on the `ETPShardedParam`. `_finalize_wgrad` calls the hook **manually** after RS wait + gradient accumulation: + +```python +# _finalize_wgrad (called after RS is waited and gradient accumulated) +param.main_grad.add_(wgrad_rs) # gradient accumulated +param.grad = dummy_grad # Python attr set (does NOT fire autograd) +param._grad_accum_hook() # manually triggers DDP register_grad_ready +``` + +This fires `register_grad_ready` at exactly the right serialization point, ensuring DDP RS launches only after ETP RS completes. The hook trigger differs by execution mode: + +| Weight type | Hook trigger location | When | +|---|---|---| +| **Graphed dense** (mamba/attn/shared_expert) | `_CudagraphReplayNode.backward` in `cuda_graphs.py` | After graph replay (Python, not captured) | +| **Eager expert** (grouped_fc1/fc2) | `_finalize_wgrad` in `extended_tensor_parallelism.py` | After RS wait + `main_grad.add_` (Python, every iteration) | +| **Eager chain head** (sync RS) | `_finalize_wgrad` called directly in `wgrad_reduce_scatter` | Immediately after sync RS completes | + +For graphed params: `_finalize_wgrad` runs during capture but the hook returns early (`is_graph_capturing()`). At replay, `_finalize_wgrad` doesn't re-run from Python (captured GPU ops only). `_CudagraphReplayNode.backward` explicitly triggers the hook after setting `grad_added_to_main_grad = True`. + +### Forward-Path Drains at CG/Eager Boundary + +Before eager expert compute starts (`_forward_mlp_expert_compute`), two drains ensure no in-flight IB ops race with expert backward: + +1. `_drain_etp_side_streams('dense')` — drains the dense ETP AG prefetch (e.g., `AG(next_mamba_fc1)` launched by the preceding shared_expert GEMM on `ag_stream`) +2. `_drain_param_gather()` — drains async DDP param all-gather from `--overlap-param-gather` + CG (the forward pre-hook `finish_param_sync` is skipped during graph capture/replay) + +--- + +## Scalability + +ETP scales along two independent dimensions: + +1. **ETP group size (N)**: Divides per-GPU weight memory by N. With N=8 and BF16 weights, a weight of 8 GB is reduced to 1 GB per GPU. With NVFP4, the same weight becomes 250 MB per GPU. + +2. **Number of experts (E)** (MoE only): Expert weights are gathered in parallel via a batched coalesced all-gather (`grouped_gather_along_first_dim`), so adding more experts within a MoE layer does not serialize the communication. + +**Combined scaling**: In a model with TP×ETP parallelism, the effective per-GPU weight size is `W / (TP × ETP)`. For example, TP=4 + ETP=8 gives 32× weight compression before training data parallelism is even considered. + +**Prefetch chain amortizes communication**: The linked-list prefetch means that for an L-layer model, L-1 all-gathers are completely hidden behind compute. Only the very first layer's all-gather (or the first backward layer) may stall, and only if the GPU compute is faster than the network. + + +TODO: add scalability perf of Ultra in nvfp4. + +--- + +## Schedule Details + +### Forward Pass + +``` +Layer i-1 fwd Layer i fwd Layer i+1 fwd +┌─────────────────────────┐ ┌─────────────────────────┐ ┌────────────── +│ all_gather_and_prefetch │ │ all_gather_and_prefetch │ │ ... +│ ├─ get W_i-1 (cached) │ │ ├─ get W_i (cached) │ │ +│ └─ async AG W_i ───── │─────▶ ready at use time │ │ +│ │ │ │ │ +│ GEMM(input, W_i-1) │ │ GEMM(input, W_i) │ │ +└─────────────────────────┘ └─────────────────────────┘ └────────────── + ↑ Overlap ↑ Overlap + AG(W_i) ∥ GEMM(W_i-1) AG(W_i+1) ∥ GEMM(W_i) +``` + +Step by step for layer `i`: + +1. **Lazy linked-list construction** (first pass only): Each `ETPShardedParam` has a `prefetch_initialized` flag. On the first call to `all_gather_and_prefetch`, this flag is `False`. The weight links itself to the previous weight (`cls._last_weight`) by setting `prev_w` / `next_w`, then sets `prefetch_initialized = True`. On subsequent passes the linking block is skipped. The complete link table is buffered during the first pass and flushed atomically as a single log print at the start of the second pass. +2. **Retrieve current weight**: + - `prev_w is not None` and `_ag_ticket_fwd is not None` → pull from buffer cache (ticket already reserved by async prefetch) + - otherwise → synchronous on-demand all-gather (only on very first use or when prefetch is disabled) +3. **Quantize if needed** (FP8/NVFP4/MXFP8): re-quantize the local shard into its pre-allocated quantized buffer before communication. +4. **Run GEMM** using the gathered full weight. +5. **Async prefetch next weight**: kick off `_all_gather_weight(async_op=True)` for `next_w` and store the handle in `next_w._prefetch_handle`. +6. **Release buffer**: after returning the gathered weight to the caller, the buffer for the current weight is returned to the pool via `cache.release(ticket)`. +7. **Save sharded weight** (not gathered) for the backward pass: `weight_etp_sharded` is stored in `ctx`; the gathered buffer is transient. + +#### Prefetch implementation sketch + +```python +# all_gather_and_prefetch (simplified) +if self.prev_w is not None and self._ag_ticket_fwd is not None: + result = self._get_prefetched_weight(fwd=True, ...) # cached +else: + result = self._all_gather_weight_on_demand(fwd=True, ...) # sync fallback + +if ETP_CONFIG.weight_prefetch and self.next_w is not None: + _, handle = self.next_w._all_gather_weight(async_op=True, ...) + self.next_w._prefetch_handle = handle + +if self.prev_w is not None: + cache.release(self._ag_ticket_fwd) # return consumed buffer to pool + +# First-pass only: link into prefetch chain +if not self.prefetch_initialized: + if cls._last_weight is not None and cls._last_weight.next_w is None: + cls._buffer_link_table_row(cls._last_weight, self) + cls._last_weight.next_w = self + self.prev_w = cls._last_weight + self.prefetch_initialized = True +elif not cls._link_table_flushed and cls._link_table_buffer: + cls._link_table_flushed = True + print_rank_0("\n".join(cls._link_table_buffer) + "\n") # atomic flush +cls._last_weight = self +``` + +The all-gather for layer `i+1` runs on the dedicated `AG_STREAM` while the GEMM for layer `i` runs on the compute stream, giving near-perfect overlap for GPU-compute-bound models. Similarly, the wgrad reduce-scatter runs on `RS_STREAM`. Both streams signal completion via CUDA events (`ag_event`, `rs_event`) that are waited on the compute stream before the result is consumed, ensuring correct ordering without blocking either communication stream. + +### Backward Pass + +The backward schedule mirrors forward, but traverses the layer chain in reverse: + +``` +Layer i+1 bwd Layer i bwd Layer i-1 bwd +┌─────────────────────────┐ ┌─────────────────────────┐ ┌────────────── +│ all_gather_and_prefetch │ │ all_gather_and_prefetch │ │ ... +│ ├─ get W_i+1 (cached) │ │ ├─ get W_i (cached) │ │ +│ └─ async AG W_i ────── │─────▶ ready at use time │ │ +│ │ │ │ │ +│ dgrad GEMM(grad, W_i+1) │ │ dgrad GEMM(grad, W_i) │ │ +│ wgrad GEMM(act, grad) │ │ wgrad GEMM(act, grad) │ │ +│ async RS(wgrad_i+1) ─── │─────▶ finish RS before use │ │ +└─────────────────────────┘ └─────────────────────────┘ └────────────── +``` + +Step by step for layer `i` backward: + +1. **`all_gather_and_prefetch_bwd()`**: Gather `W_i` for the dgrad GEMM; simultaneously async-prefetch `W_i-1` (the `prev_w`) for the next backward step. Uses `skip_weight_cast=True` — no re-quantization needed since scales are already valid from the forward pass. +2. **dgrad GEMM**: Compute `dX = dY × W_i` using the gathered weight. +3. **wgrad GEMM**: Compute `dW = X^T × dY` using the saved input activation. +4. **`wgrad_reduce_scatter(wgrad)`**: + - **Non-last layer** (`prev_w is not None`): Launch async reduce-scatter; store `ETPShardHandle` in `self._wgrad_rs_handle`. Return `None` to backward (gradient deferred). + - **Last layer** (`prev_w is None`): Synchronous reduce-scatter. Call `_finalize_wgrad()` immediately — resets `rs_state` to `NONE`, strips padding (if last rank is padded), accumulates into `param.main_grad`, returns a dummy-zero grad tensor to autograd. +5. **Deferred finish**: At the start of each subsequent layer's `wgrad_reduce_scatter`, `self.next_w._wait_reduce_scatter()` is called, which waits on `next_w._wgrad_rs_handle` and records a CUDA event. Then `_finalize_wgrad()` is called for `next_w` to reset `rs_state`, strip padding, and accumulate into `main_grad`. The RS buffer is returned to the pool via `cache.release()`. + + +Here is an example of ETP schedule diagram for Hybried Nemotron6 in bf16 as an example (ETP+EP with partial CGs): + +![alt text](etp/etp_ep_nt6_schedule_bf16.png) + + +### Coalesced Expert Communication + +For MoE layers with multiple routed experts, all experts' all-gathers are coalesced into a single NCCL operation via `torch.distributed._coalescing_manager`. This reduces NCCL kernel launch overhead and improves bus utilization compared to E sequential all-gathers. The wgrad reduce-scatter for all experts is similarly coalesced. + +--- + +## Low-Precision Details + +### FP8 (per-tensor scaling) + +- Each `ETPShardedParam` is assigned a quantizer via `setup(weight_quantizer)`. +- The quantizer is configured with `amax_reduction_group=etp_group` (the group is already stored in the param from construction), so the amax is all-reduced across the ETP group before scaling—ensuring all GPUs in the group use the same scale factor for the full weight. +- On the first microbatch (`is_first_microbatch=True`), `_quantize_if_needed()` re-quantizes the shard. On subsequent microbatches, `skip_weight_cast=True` reuses the existing quantized buffer, saving re-quantization cost. +- A `cast_noop_flag` tensor (from the FP8 recipe) can signal that no scale update is needed, enabling a no-op cast path. + +### NVFP4 (4-bit, block-scaled) + +NVFP4 requires special communication handling because: +- Each 4-bit value shares a scale with its 16-element block. +- The layout has both rowwise and columnwise views, each with separate data and `scale_inv` tensors. +- After all-gather, the interleaved format must be re-assembled into a GEMM-ready layout. + +The `_all_gather_nvfp4()` function in `distributed.py` handles this: +1. **Pre-communication**: Strips padding from `scale_inv` tensors (padding ensures alignment to communication boundaries). +2. **All-gather**: Gathers both `data` and `scale_inv` for the rowwise view; similarly for the columnwise view (with transposed tensor handling). +3. **Post-processing** (`_post_process_nvfp4_gather` / `post_process_nvfp4_gather`): + - Fixes interleaved data layout back to packed format. + - Re-pads `scale_inv` to the GEMM-required alignment. + - Transitions the tensor to `GEMM_READY` state. + +For async all-gathers, post-processing is deferred into `_NVFP4AllGatherAsyncHandle.wait()`, keeping it off the critical path. + +For routed experts, `BatchedNVFP4AllGatherAsyncHandle` wraps one handle per expert; the single outer coalescing-manager handle is waited first, then each expert's NVFP4 post-processing is applied sequentially. + +`_strip_padding` handles NVFP4 scale_inv correctly: +- `rowwise_scale_inv`: strip to `round_up(M, 128)` rows (dim 0) +- `columnwise_scale_inv`: strip to `round_up(ceil(M / 16), 4)` columns (dim 1, transposed) + +### MXFP8 (microscaling FP8) + +MXFP8 follows the same quantize-then-gather pattern as FP8. The amax reduction for microscaling is handled within the quantizer; ETP configures the reduction group to be the ETP group. + +`_strip_padding` handles MXFP8 scale_inv correctly: +- `rowwise_scale_inv`: strip to `round_up(M, 128)` rows (dim 0) +- `columnwise_scale_inv`: strip to `round_up(M // 32, 4)` rows (dim 0; columnwise is not transposed for MXFP8) + +### Bandwidth Savings from Quantization + +| Dtype | Size vs BF16 | Example: 8B param weight | +|---|---|---| +| BF16 | 1× | 16 GB per ETP group | +| FP8 | 0.5× | 8 GB | +| NVFP4 | 0.25× | 4 GB | + +With ETP size N=8 and NVFP4, each GPU holds and gathers 0.5 GB instead of the full 16 GB. + +--- + +## Memory Savings + +### Per-GPU Weight Memory + +With ETP group size N, each GPU stores only `1/N` of each weight at rest. The gathered weight is transient (lives only during the GEMM) and reused from the pool. + +### Ticket-Based Buffer Pool + +`ETPWeightCache` pools gathered weight buffers by `(shape, dtype, fwd, expert_idx, reduce_scatter)` key so that same-shaped weights across layers reuse a single GPU allocation instead of allocating per-layer. + +#### Data structures + +``` +_pool : { cache_key → [buf, buf, ...] } available (released) buffers +_slots : { ticket_id → _TicketSlot } persistent per-param ticket slots + (key, param, dtype, fwd, reduce_scatter, buf) +_next_ticket : int monotonically increasing ticket ID counter +``` + +Each `ETPShardedParam` holds up to three tickets: +- `_ag_ticket_fwd` — forward all-gather buffer +- `_ag_ticket_bwd` — backward all-gather buffer +- `_rs_ticket` — reduce-scatter buffer + +A buffer lives in **exactly one** place at a time: + +``` +reserve() → slot created, buf=None (no allocation yet) +get(ticket) → buf allocated lazily from pool or fresh; stored in slot.buf (idempotent) +release(ticket) → buf appended to pool (slot.buf stays set; production code calls release + only after get() has emptied the pool for that key, so the duplicate-check + in release() is never triggered) +clear() → all slot.buf = None, pool cleared (tickets stay valid; next get() re-allocates) +``` + +#### CUDA Graph support + +Before graph capture, call `reallocate_etp_cache_to_mempool(device, mempool)` to migrate all pool buffers into the CUDA graph memory pool. This ensures no allocations occur inside the captured graph. + +### No Activation Duplication + +The sharded weight (`weight_etp_sharded`) is saved for the backward pass instead of the gathered weight. This avoids keeping a full-size weight copy in the gradient tape, which would negate the memory savings. + +### Quantized Shard Storage + +When using FP8/NVFP4/MXFP8, only the quantized shard (not BF16) is stored persistently in `ETPShardedParam.quantized`. The full-precision master weight can reside in the optimizer state on CPU or be managed separately, keeping GPU footprint at quantized shard size. + +--- + +## API Usage + +
+Click to expand + +```python +import torch.distributed as dist +from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP +from transformer_engine.pytorch.module.extended_tensor_parallelism import ( + tag_etp_params_with_names, + update_config, +) + +# Set up process groups +tp_group = ... # Tensor-parallel group +etp_group = ... # ETP group (orthogonal to TP) + +# Drop-in replacement for standard TE Linear (dense model) +# Weights are sharded at construction time by wrap_module_params_etp +layer = Linear( + in_features=4096, + out_features=4096, + parallel_mode="column", # or "row" + tp_group=tp_group, + etp_group=etp_group, # Enable ETP +) + +# Also works with LayerNormLinear and LayerNormMLP (dense or MoE feed-forward) +ffn = LayerNormMLP( + hidden_size=4096, + ffn_hidden_size=16384, + tp_group=tp_group, + etp_group=etp_group, # Enable ETP +) + +# Weight is automatically an ETPShardedParam holding only the local shard +assert isinstance(layer.weight, ETPShardedParam) + +# Call setup() once after constructing quantizers (FP8/NVFP4). +# Note: etp_group is already stored in the param; setup() only takes quantizers. +layer.weight.setup(weight_quantizer=quantizers) + +# Optionally tag all ETP params with human-readable names for the link table log. +# Call once after full model construction. +tag_etp_params_with_names(model) + +# Forward/backward are transparent — ETP handles all-gather/reduce-scatter internally +output = layer(input) +``` + +
+ +For MoE layers with routed experts, `GroupedLinear` uses the same `etp_group` argument and handles batched expert weight gathers automatically. + +--- + +## Implementation Files + +| File | Role | +|---|---| +| `transformer_engine/pytorch/module/extended_tensor_parallelism.py` | Core ETP: `ETPShardedParam`, `ETPWeightCache`, `_TicketSlot`, `ETPWeightState`, `ETPConfig`, `wrap_module_params_etp`, `tag_etp_params_with_names`, `update_config`, `reallocate_etp_cache_to_mempool`, `wait_async_comms` | +| `transformer_engine/pytorch/module/linear.py` | ETP integration in `Linear` forward/backward | +| `transformer_engine/pytorch/module/layernorm_linear.py` | ETP integration in `LayerNormLinear` forward/backward | +| `transformer_engine/pytorch/module/layernorm_mlp.py` | ETP integration in `LayerNormMLP` forward/backward | +| `transformer_engine/pytorch/module/grouped_linear.py` | ETP integration for MoE routed-expert grouped GEMMs | +| `transformer_engine/pytorch/distributed.py` | `gather_along_first_dim`, `_all_gather_nvfp4`, `_NVFP4AllGatherAsyncHandle` | +| `tests/pytorch/distributed/test_etp.py` | ETP unit tests: state machine, buffer cache, weight sharding, module param replacement, `Linear`/`LayerNormLinear`/`GroupedLinear` fwd/bwd correctness, prefetch chain, wgrad reduce-scatter, microbatches, NVFP4 fwd/bwd (aligned + unaligned), MXFP8 fwd/bwd (aligned + unaligned) | +| `tests/pytorch/distributed/test_tp_etp.py` | TP+ETP integration tests: process group layout, `Linear` (column/row parallel) weight shape and fwd/bwd correctness, `LayerNormLinear` and `LayerNormMLP` fwd/bwd smoke tests; runs on 4 GPUs with TP=2, ETP=2 | + +---- + +## Best Practice + +TODO + +---- + +## Caveats + +- First forward pass always stalls (cold start) + + On the very first forward pass, `state == NONE` for all weights (no prefetch has run yet), so every weight does a synchronous all-gather. Only from the second pass onward does the async prefetch chain kick in. For frameworks that benchmark the first iteration (e.g., profilers, compilation warmup), this cold-start stall looks like a regression. + +- Link table logged on second forward pass + + The prefetch-link table (printed via `tag_etp_params_with_names` + the built-in logging) is buffered during the first forward pass and flushed atomically at the start of the second forward pass. This ensures it is not interleaved with other logs, but means it will not appear until the second iteration. + +---- + +## Future Work + +TODO + +---- diff --git a/docs/etp/etp_ep_nt6_schedule_bf16.png b/docs/etp/etp_ep_nt6_schedule_bf16.png new file mode 100644 index 0000000000..828f7b14a6 Binary files /dev/null and b/docs/etp/etp_ep_nt6_schedule_bf16.png differ diff --git a/tests/pytorch/distributed/test_etp.py b/tests/pytorch/distributed/test_etp.py new file mode 100644 index 0000000000..39afe69b00 --- /dev/null +++ b/tests/pytorch/distributed/test_etp.py @@ -0,0 +1,1411 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Unit tests for Extended Tensor Parallelism (ETP). + +Test groups +----------- +1. TestETPWeightState – state-machine transitions (single-process) +2. TestETPWeightCache – coat-check buffer pool (single-process) +3. TestETPSharding – wrap_module_params_etp: shard content + padding (multi-GPU) +4. TestWrapModuleParams – wrap_module_params_etp: param replacement + weight_list (multi-GPU) +5. TestLinearETP – Linear forward/backward numerical correctness (multi-GPU) +6. TestLayerNormLinearETP – LayerNormLinear forward/backward smoke test (multi-GPU) +7. TestGroupedLinearETP – GroupedLinear forward/backward smoke test (multi-GPU) +8. TestETPPrefetchChain – linked-list next_w/prev_w wiring (multi-GPU) +9. TestETPWgradRS – wgrad reduce-scatter shape + multi-layer deferred path (multi-GPU) +10. TestETPMicrobatches – output consistency across microbatches (multi-GPU) +11. TestNVFP4LinearETP – Linear + NVFP4 recipe: quantized shard setup, fwd/bwd (multi-GPU) +12. TestNVFP4GroupedLinearETP – GroupedLinear + NVFP4 recipe: coalesced AG + fwd/bwd (multi-GPU) +13. TestMXFP8LinearETP – Linear + MXFP8 recipe: quantized shard setup, fwd/bwd, padding (multi-GPU) +14. TestETPConfig – update_config: valid/invalid keys (single-process) +15. TestETPShardedParamProperties – shape computations, get_padded_shard, _strip_padding (single-process) +16. TestETPCacheKey – _get_cache_key: expert vs non-expert, fwd vs bwd (single-process) +17. TestETPCacheRelease – reserve/get/release pool semantics (single-process) +18. TestTagETPParamsWithNames – _debug_name population on ETPShardedParam (single-process) +19. TestFinalizeWgrad – _finalize_wgrad: accumulate, strip padding, rs_state reset (single-process) +20. TestETPGroupSizeOne – wrap_module_params_etp no-op when etp_group.size()==1 (single-process) +21. TestETPPrefetchDisabled – weight_prefetch=False: single-pass forward still works (multi-GPU) +22. TestFuseWgradAccumulation – fuse_wgrad_accumulation=True: wgrad→main_grad (multi-GPU) +23. TestETPGradAccumHook – main_grad updated after reduce-scatter backward (multi-GPU) + +Multi-GPU tests use torch.multiprocessing.spawn and are skipped when fewer +than the required CUDA devices are available. +""" + +import os +import socket + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.module.extended_tensor_parallelism as etp_module +from transformer_engine.pytorch.module.extended_tensor_parallelism import ( + ETPShardedParam, + ETPWeightCache, + ETPWeightState, + wrap_module_params_etp, +) +from transformer_engine.pytorch import fp8_autocast, is_nvfp4_available, is_mxfp8_available +from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor +from transformer_engine.common.recipe import NVFP4BlockScaling + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def reset_fp8_state(): + yield + FP8GlobalStateManager.reset() + + +@pytest.fixture(autouse=True) +def reset_etp_globals(): + """Reset all ETP mutable class/module-level state between tests.""" + yield + ETPShardedParam._first_weight_flag = True + ETPShardedParam._last_weight = None + ETPShardedParam._pending_rs_weight = None + ETPShardedParam._link_node_count = 0 + ETPShardedParam._link_table_buffer = [] + ETPShardedParam._link_table_flushed = False + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _dist_init(rank: int, world_size: int, port: int) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +def _run_distributed(fn, world_size: int, *args) -> None: + """Spawn `world_size` processes each running fn(rank, world_size, port, *args).""" + port = _free_port() + mp.spawn(fn, args=(world_size, port) + args, nprocs=world_size, join=True) + + +def _requires_multi_gpu(n: int = 4): + if torch.cuda.device_count() < n: + pytest.skip(f"Requires at least {n} CUDA devices") + + +def _requires_nvfp4(): + if not is_nvfp4_available(): + pytest.skip("NVFP4 not available (requires compute capability >= 10.0)") + + +# --------------------------------------------------------------------------- +# 1. ETPWeightState – state-machine transition tests +# --------------------------------------------------------------------------- + +class TestETPWeightState: + + @staticmethod + def _param(): + return ETPShardedParam(torch.zeros(4, 4)) + + def test_full_cycle(self): + p = self._param() + assert p.state == ETPWeightState.NONE + p._set_state(ETPWeightState.ASYNC_WAIT) + p._set_state(ETPWeightState.DATA_READY) + p._set_state(ETPWeightState.NONE) + assert p.state == ETPWeightState.NONE + + def test_sync_path_cycle(self): + """NONE → DATA_READY_SYNC → NONE (sync all-gather path).""" + p = self._param() + p._set_state(ETPWeightState.DATA_READY_SYNC) + p._set_state(ETPWeightState.NONE) + assert p.state == ETPWeightState.NONE + + def test_rs_state_full_cycle(self): + """RS state machine: NONE → ASYNC_WAIT → DATA_READY → NONE.""" + p = self._param() + assert p.rs_state == ETPWeightState.NONE + p._set_rs_state(ETPWeightState.ASYNC_WAIT) + p._set_rs_state(ETPWeightState.DATA_READY) + p._set_rs_state(ETPWeightState.NONE) + assert p.rs_state == ETPWeightState.NONE + + +# --------------------------------------------------------------------------- +# 2. ETPWeightCache – coat-check buffer pool tests +# --------------------------------------------------------------------------- + +class TestETPWeightCache: + + class _FakeGroup: + def __init__(self, size=2): + self._size = size + def size(self): + return self._size + def rank(self): + return 0 + + def _param(self, shape=(8, 4), etp_size=2): + p = ETPShardedParam(torch.zeros(*shape)) + p.group = self._FakeGroup(etp_size) + p.expert_idx = None + p.pad_length = 0 + p.is_padded_last_rank = False + p._quantizer = None + return p + + def test_reserve_returns_ticket(self): + cache = ETPWeightCache() + p = self._param() + ticket = cache.reserve(p, torch.bfloat16, fwd=True) + assert isinstance(ticket, int) + + def test_reserve_get_roundtrip(self): + cache = ETPWeightCache() + p = self._param() + ticket = cache.reserve(p, torch.bfloat16, fwd=True) + buf = cache.get(ticket) + assert buf is not None + # get() returns same buf on second call (buf cached in slot) + buf2 = cache.get(ticket) + assert buf2 is buf + + def test_buffer_reused_after_release(self): + cache = ETPWeightCache() + p = self._param() + t1 = cache.reserve(p, torch.bfloat16, fwd=True) + buf1 = cache.get(t1) + cache.release(t1) + # Reserve a new ticket, buf should come from pool + t2 = cache.reserve(p, torch.bfloat16, fwd=True) + buf2 = cache.get(t2) + assert buf1 is buf2, "Buffer should be reused from pool after release" + cache.release(t2) + + def test_two_simultaneous_reserves_are_distinct(self): + cache = ETPWeightCache() + p = self._param() + t1 = cache.reserve(p, torch.bfloat16, fwd=True) + buf1 = cache.get(t1) + t2 = cache.reserve(p, torch.bfloat16, fwd=True) + buf2 = cache.get(t2) + assert buf1 is not buf2, "Concurrent reserves must get distinct buffers" + + def test_tickets_are_unique(self): + """Each reserve() call returns a new unique ticket.""" + cache = ETPWeightCache() + p = self._param() + t1 = cache.reserve(p, torch.bfloat16, fwd=True) + t2 = cache.reserve(p, torch.bfloat16, fwd=True) + assert t1 != t2, "Each reserve() must return a unique ticket" + + def test_invalid_ticket_raises(self): + cache = ETPWeightCache() + with pytest.raises(KeyError): + cache.get(9999) + + def test_different_shapes_use_distinct_pool_slots(self): + cache = ETPWeightCache() + p1 = self._param(shape=(8, 4)) + p2 = self._param(shape=(16, 4)) + t1 = cache.reserve(p1, torch.bfloat16, fwd=True) + buf1 = cache.get(t1) + t2 = cache.reserve(p2, torch.bfloat16, fwd=True) + buf2 = cache.get(t2) + assert buf1.shape != buf2.shape + cache.release(t1); cache.release(t2) + + def test_fwd_bwd_tickets_are_distinct(self): + """fwd=True and fwd=False reserves always receive distinct ticket IDs.""" + cache = ETPWeightCache() + p = self._param() + t_fwd = cache.reserve(p, torch.bfloat16, fwd=True) + t_bwd = cache.reserve(p, torch.bfloat16, fwd=False) + assert t_fwd != t_bwd + + +# --------------------------------------------------------------------------- +# 3. ETP weight sharding: shard content and alignment padding +# --------------------------------------------------------------------------- + +def _worker_sharding_aligned(rank, world_size, port): + _dist_init(rank, world_size, port) + K, M = world_size * 32, 16 # K divisible by 16*world_size → no padding + full_weight = torch.arange(K * M, dtype=torch.float32).reshape(K, M).cuda() + dist.broadcast(full_weight, src=0) + + etp_group = dist.new_group(list(range(world_size))) + mod = nn.Module() + mod.weight = nn.Parameter(full_weight.clone(), requires_grad=False) + wrap_module_params_etp(mod, ['weight'], etp_group) + shard = mod.weight + + rows_per_rank = K // world_size + assert shard.shape == (rows_per_rank, M), f"rank {rank}: unexpected shape {shard.shape}" + assert shard.pad_length == 0 + expected = full_weight[rank * rows_per_rank : (rank + 1) * rows_per_rank] + assert torch.allclose(shard.data, expected), f"rank {rank}: shard content mismatch" + dist.destroy_process_group() + + +def _worker_sharding_padding(rank, world_size, port): + _dist_init(rank, world_size, port) + alignment = 16 * world_size + K = alignment - 1 # deliberately unaligned + M = 16 + full_weight = torch.ones(K, M, dtype=torch.float32).cuda() + dist.broadcast(full_weight, src=0) + + etp_group = dist.new_group(list(range(world_size))) + mod = nn.Module() + mod.weight = nn.Parameter(full_weight.clone(), requires_grad=False) + wrap_module_params_etp(mod, ['weight'], etp_group) + shard = mod.weight + + padded_K = alignment + rows_per_rank = padded_K // world_size + + if rank == world_size - 1: + assert shard.is_padded_last_rank + assert shard.pad_length > 0 + # The shard tensor holds only the real rows; get_padded_shard() appends zero rows. + padded = shard.get_padded_shard() + assert padded.shape[0] == rows_per_rank, \ + f"rank {rank}: expected padded shard {rows_per_rank} rows, got {padded.shape[0]}" + n_real = K - rank * rows_per_rank + assert torch.all(padded[n_real:] == 0), "Padding rows must be zero" + else: + assert not shard.is_padded_last_rank + assert shard.shape[0] == rows_per_rank, \ + f"rank {rank}: expected {rows_per_rank} rows, got {shard.shape[0]}" + + dist.destroy_process_group() + + +class TestETPSharding: + def test_aligned_shard_content(self): + _requires_multi_gpu(4) + _run_distributed(_worker_sharding_aligned, 4) + + def test_unaligned_shard_padding(self): + _requires_multi_gpu(4) + _run_distributed(_worker_sharding_padding, 4) + + +# --------------------------------------------------------------------------- +# 4. wrap_module_params_etp: param replacement and GroupedLinear weight_list +# --------------------------------------------------------------------------- + +def _worker_linear_param_replaced(rank, world_size, port): + _dist_init(rank, world_size, port) + in_f, out_f = 64, 128 + etp_group = dist.new_group(list(range(world_size))) + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=torch.bfloat16, + device="cuda", etp_group=etp_group, + ) + w = layer.weight + assert isinstance(w, ETPShardedParam), "weight must be ETPShardedParam" + assert w.shape == (out_f // world_size, in_f), f"unexpected shard shape {w.shape}" + assert w.group is etp_group + dist.destroy_process_group() + + +def _worker_grouped_weight_list(rank, world_size, port): + _dist_init(rank, world_size, port) + num_gemms, in_f, out_f = 3, 32, 64 + etp_group = dist.new_group(list(range(world_size))) + layer = te.GroupedLinear( + num_gemms=num_gemms, in_features=in_f, out_features=out_f, + bias=False, params_dtype=torch.bfloat16, + device="cuda", etp_group=etp_group, + ) + w0 = layer.weight0 + assert isinstance(w0, ETPShardedParam) + assert w0.weight_list is not None + assert len(w0.weight_list) == num_gemms + assert [w.expert_idx for w in w0.weight_list] == list(range(num_gemms)) + dist.destroy_process_group() + + +class TestWrapModuleParams: + def test_linear_weight_replaced(self): + _requires_multi_gpu(4) + _run_distributed(_worker_linear_param_replaced, 4) + + def test_grouped_linear_weight_list(self): + _requires_multi_gpu(4) + _run_distributed(_worker_grouped_weight_list, 4) + + +# --------------------------------------------------------------------------- +# 5. Linear forward/backward numerical correctness +# --------------------------------------------------------------------------- + +def _worker_linear_correctness(rank, world_size, port): + """ETP output == (all-gathered weight) @ input, and dX matches.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + batch, in_f, out_f = 16, 64, 128 # out_f % (16*world_size)==0 → no padding + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + ) + + # Reconstruct full weight from shards (all-gather) + shard = layer.weight.data.clone() + all_shards = [torch.zeros_like(shard) for _ in range(world_size)] + dist.all_gather(all_shards, shard, group=etp_group) + full_weight = torch.cat(all_shards, dim=0).float()[:out_f] # strip any padding + + # Shared input across ranks + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda") + dist.broadcast(inp, src=0) + + inp_etp = inp.clone().requires_grad_(True) + inp_ref = inp.clone().requires_grad_(True) + + # ETP forward + out_etp = layer(inp_etp, is_first_microbatch=True) + + # Reference forward + out_ref = inp_ref.float() @ full_weight.T + out_ref = out_ref.to(dtype) + + assert out_etp.shape == out_ref.shape, f"Shape mismatch {out_etp.shape} vs {out_ref.shape}" + assert torch.allclose(out_etp.float(), out_ref.float(), atol=0.1, rtol=0.1), ( + f"Output mismatch max_diff={(out_etp.float()-out_ref.float()).abs().max():.4f}" + ) + + # _finalize_wgrad always accumulates into main_grad; allocate before backward. + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + + # Backward: compare input gradient + grad_out = torch.randn_like(out_etp) + dist.broadcast(grad_out, src=0) + out_etp.backward(grad_out) + out_ref.backward(grad_out.float()) + + assert inp_etp.grad is not None + assert torch.allclose(inp_etp.grad.float(), inp_ref.grad.float(), atol=0.1, rtol=0.1), ( + f"dX mismatch max_diff={(inp_etp.grad.float()-inp_ref.grad.float()).abs().max():.4f}" + ) + dist.destroy_process_group() + + +class TestLinearETP: + def test_forward_backward_correctness(self): + _requires_multi_gpu(4) + _run_distributed(_worker_linear_correctness, 4) + + +# --------------------------------------------------------------------------- +# 6. LayerNormLinear forward/backward smoke test +# --------------------------------------------------------------------------- + +def _worker_layernorm_linear(rank, world_size, port): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + seq, batch, in_f, out_f = 4, 2, 64, 128 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.LayerNormLinear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + ) + assert isinstance(layer.weight, ETPShardedParam) + + inp = torch.randn(seq, batch, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + out = layer(inp, is_first_microbatch=True) + assert out.shape == (seq, batch, out_f), f"unexpected output shape {out.shape}" + + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + out.sum().backward() + assert inp.grad is not None and inp.grad.shape == inp.shape + dist.destroy_process_group() + + +class TestLayerNormLinearETP: + def test_forward_backward(self): + _requires_multi_gpu(4) + _run_distributed(_worker_layernorm_linear, 4) + + +# --------------------------------------------------------------------------- +# 7. GroupedLinear forward/backward smoke test +# --------------------------------------------------------------------------- + +def _worker_grouped_linear(rank, world_size, port, num_gemms): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f, total_tokens = 32, 64, num_gemms * 4 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.GroupedLinear( + num_gemms=num_gemms, in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + ) + assert isinstance(layer.weight0, ETPShardedParam) + + m_splits = [total_tokens // num_gemms] * num_gemms + m_splits[-1] += total_tokens - sum(m_splits) + + inp = torch.randn(total_tokens, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + out = layer(inp, m_splits=m_splits, is_first_microbatch=True) + assert out.shape == (total_tokens, out_f), f"unexpected output shape {out.shape}" + + for i in range(num_gemms): + w = getattr(layer, f"weight{i}") + w.main_grad = torch.zeros(w.shape, dtype=dtype, device="cuda") + out.sum().backward() + assert inp.grad is not None and inp.grad.shape == inp.shape + dist.destroy_process_group() + + +class TestGroupedLinearETP: + @pytest.mark.parametrize("num_gemms", [2, 4]) + def test_forward_backward(self, num_gemms): + _requires_multi_gpu(4) + _run_distributed(_worker_grouped_linear, 4, num_gemms) + + +# --------------------------------------------------------------------------- +# 8. Prefetch chain: next_w / prev_w wiring after first forward pass +# --------------------------------------------------------------------------- + +def _worker_chain_wired(rank, world_size, port): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", etp_group=etp_group) + l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", etp_group=etp_group) + + inp = torch.randn(4, in_f, dtype=dtype, device="cuda") + dist.broadcast(inp, src=0) + + # First forward pass builds the linked list + l0(inp, is_first_microbatch=True) + l1(inp, is_first_microbatch=True) + + w0, w1 = l0.weight, l1.weight + assert w0.next_w is w1, "w0.next_w should point to w1" + assert w1.prev_w is w0, "w1.prev_w should point back to w0" + assert w1.next_w is None + assert w0.prev_w is None + dist.destroy_process_group() + + +def _worker_chain_async_prefetch(rank, world_size, port): + """On the second forward pass, w1 should be in DATA_READY before its forward runs.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", etp_group=etp_group) + l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", etp_group=etp_group) + + inp = torch.randn(4, in_f, dtype=dtype, device="cuda") + dist.broadcast(inp, src=0) + + # First pass builds chain, second pass uses async prefetch + for _ in range(2): + out = l0(inp, is_first_microbatch=True) + l1(inp, is_first_microbatch=True) + assert torch.isfinite(out).all(), "Non-finite output on second pass" + dist.destroy_process_group() + + +class TestETPPrefetchChain: + def test_chain_wired_after_first_pass(self): + _requires_multi_gpu(4) + _run_distributed(_worker_chain_wired, 4) + + def test_async_prefetch_second_pass(self): + _requires_multi_gpu(4) + _run_distributed(_worker_chain_async_prefetch, 4) + + +# --------------------------------------------------------------------------- +# 9. Wgrad reduce-scatter: shape and deferred async path +# --------------------------------------------------------------------------- + +def _worker_wgrad_shape(rank, world_size, port): + """After backward, weight.grad shape must match the local shard shape.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + fuse_wgrad_accumulation=False, + ) + inp = torch.randn(8, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + layer(inp, is_first_microbatch=True).sum().backward() + + w = layer.weight + if w.grad is not None: + assert w.grad.shape == w.shape, \ + f"wgrad shape {w.grad.shape} != shard shape {w.shape}" + dist.destroy_process_group() + + +def _worker_multilayer_deferred_rs(rank, world_size, port): + """Two-layer ETP: async RS deferred for layer0 (non-last), sync for layer1 (last in bwd).""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", etp_group=etp_group) + l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", etp_group=etp_group) + + inp = torch.randn(8, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + # _finalize_wgrad always accumulates into main_grad; allocate before backward. + l0.weight.main_grad = torch.zeros(l0.weight.shape, dtype=dtype, device="cuda") + l1.weight.main_grad = torch.zeros(l1.weight.shape, dtype=dtype, device="cuda") + + out = l0(inp, is_first_microbatch=True) + l1(inp, is_first_microbatch=True) + out.sum().backward() + + # Both weights' main_grad should have been updated + for lyr in [l0, l1]: + w = lyr.weight + assert w.main_grad is not None, f"No main_grad on {lyr.__class__.__name__}.weight" + dist.destroy_process_group() + + +class TestETPWgradRS: + def test_wgrad_shape_matches_shard(self): + _requires_multi_gpu(4) + _run_distributed(_worker_wgrad_shape, 4) + + def test_multilayer_deferred_rs(self): + _requires_multi_gpu(4) + _run_distributed(_worker_multilayer_deferred_rs, 4) + + +# --------------------------------------------------------------------------- +# 10. Multiple microbatches: output must be consistent when weight unchanged +# --------------------------------------------------------------------------- + +def _worker_microbatches(rank, world_size, port): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + batch, in_f, out_f = 8, 64, 128 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + ) + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda") + dist.broadcast(inp, src=0) + + # First microbatch + out1 = layer(inp, is_first_microbatch=True).detach().clone() + + # Second microbatch with same weight (skip_weight_cast=True path) + out2 = layer(inp, is_first_microbatch=False).detach() + + assert torch.allclose(out1, out2), \ + f"Microbatch outputs differ; max_diff={(out1-out2).abs().max():.6f}" + dist.destroy_process_group() + + +class TestETPMicrobatches: + def test_consistent_across_microbatches(self): + _requires_multi_gpu(4) + _run_distributed(_worker_microbatches, 4) + + +# --------------------------------------------------------------------------- +# 11. NVFP4 + ETP: Linear forward/backward, quantized shard setup +# --------------------------------------------------------------------------- + +def _worker_nvfp4_linear(rank, world_size, port): + """Verify that ETP Linear correctly quantizes, all-gathers, and computes with NVFP4.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + # batch=32: NVFP4 wgrad GEMM (K=batch) requires K divisible by 32 + batch, in_f, out_f = 32, 64, 128 # out_f % (16*world_size)==0 → no padding + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + ) + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + # Forward under NVFP4 recipe – triggers setup() and NVFP4 quantization + recipe = NVFP4BlockScaling() + with fp8_autocast(enabled=True, fp8_recipe=recipe): + out = layer(inp, is_first_microbatch=True) + + # After the first forward pass setup() must have created a quantized shard + w = layer.weight + assert w.quantized is not None, "NVFP4 quantized shard must be set after setup()" + assert isinstance(w.quantized, QuantizedTensor), \ + f"weight.quantized should be QuantizedTensor, got {type(w.quantized)}" + + assert out.shape == (batch, out_f), f"unexpected output shape {out.shape}" + assert torch.isfinite(out).all(), "NVFP4 ETP output has non-finite values" + + # Second microbatch reuses cached quantized weight (skip_weight_cast path) + with fp8_autocast(enabled=True, fp8_recipe=recipe): + out2 = layer(inp.detach(), is_first_microbatch=False) + assert torch.isfinite(out2).all(), "NVFP4 ETP second-microbatch output has non-finite values" + + dist.destroy_process_group() + + +def _worker_nvfp4_linear_unaligned(rank, world_size, port): + """Verify NVFP4 ETP when out_features is not aligned to 16*world_size (padding path). + + out_f is chosen to be divisible by 8 (satisfies NVFP4 GEMM alignment) but not by + 16*world_size (so padding is needed). The last ETP rank receives a shard that is + zero-padded to reach the shard_size boundary. After all-gather, _strip_padding + removes the padded rows from the gathered weight before the GEMM, so the output + has the original out_f columns. + """ + _dist_init(rank, world_size, port) + torch.manual_seed(0) + alignment = 16 * world_size # 64 for world_size=4 + # Choose out_f divisible by 8 (NVFP4 GEMM constraint) but not by 64 (ETP alignment). + # With out_f=56: pad_length=8, shard_size=16, last rank gets 8 rows padded to 16. + out_f = alignment - 8 # 56 for world_size=4 + in_f = 64 + batch = 32 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + ) + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + with fp8_autocast(enabled=True, fp8_recipe=NVFP4BlockScaling()): + out = layer(inp, is_first_microbatch=True) + + # After _strip_padding removes the padded rows, output has out_f (not padded) cols. + assert out.shape == (batch, out_f), f"unexpected output shape {out.shape}" + assert torch.isfinite(out).all(), "NVFP4 ETP (unaligned) output has non-finite values" + dist.destroy_process_group() + + +class TestNVFP4LinearETP: + def test_forward_backward(self): + _requires_nvfp4() + _requires_multi_gpu(4) + _run_distributed(_worker_nvfp4_linear, 4) + + def test_forward_unaligned_padding(self): + _requires_nvfp4() + _requires_multi_gpu(4) + _run_distributed(_worker_nvfp4_linear_unaligned, 4) + + +# --------------------------------------------------------------------------- +# 12. NVFP4 + ETP: GroupedLinear forward/backward (coalesced batched all-gather) +# --------------------------------------------------------------------------- + +def _worker_nvfp4_grouped_linear(rank, world_size, port, num_gemms): + """Verify NVFP4 ETP with GroupedLinear (uses grouped_gather_along_first_dim).""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + # NVFP4 split_quantize constraints: in_f % 128 == 0, tokens_per_expert % 64 == 0 + # (Hadamard transform requirement), and K=tokens_per_expert % 32 == 0 for wgrad. + in_f, out_f, total_tokens = 128, 256, num_gemms * 64 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.GroupedLinear( + num_gemms=num_gemms, in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + ) + assert isinstance(layer.weight0, ETPShardedParam) + + m_splits = [total_tokens // num_gemms] * num_gemms + m_splits[-1] += total_tokens - sum(m_splits) + + inp = torch.randn(total_tokens, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + with fp8_autocast(enabled=True, fp8_recipe=NVFP4BlockScaling()): + out = layer(inp, m_splits=m_splits, is_first_microbatch=True) + + assert out.shape == (total_tokens, out_f), f"unexpected output shape {out.shape}" + assert torch.isfinite(out).all(), "NVFP4 GroupedLinear ETP output has non-finite values" + + # All expert weight shards should be quantized after setup() + for i in range(num_gemms): + name = f"weight{i}" + w = getattr(layer, name) + assert isinstance(w, ETPShardedParam) + assert w.quantized is not None, f"{name}.quantized not set after NVFP4 setup()" + assert isinstance(w.quantized, QuantizedTensor), \ + f"{name}.quantized should be QuantizedTensor, got {type(w.quantized)}" + + for i in range(num_gemms): + w = getattr(layer, f"weight{i}") + w.main_grad = torch.zeros(w.shape, dtype=dtype, device="cuda") + out.sum().backward() + assert inp.grad is not None and inp.grad.shape == inp.shape + dist.destroy_process_group() + + +class TestNVFP4GroupedLinearETP: + @pytest.mark.parametrize("num_gemms", [2, 4]) + def test_forward_backward(self, num_gemms): + _requires_nvfp4() + _requires_multi_gpu(4) + _run_distributed(_worker_nvfp4_grouped_linear, 4, num_gemms) + + +# --------------------------------------------------------------------------- +# 13. MXFP8 + ETP: Linear forward/backward, quantized shard setup +# --------------------------------------------------------------------------- + +def _worker_mxfp8_linear(rank, world_size, port): + """Verify that ETP Linear correctly quantizes, all-gathers, and computes with MXFP8.""" + from transformer_engine.common.recipe import MXFP8BlockScaling + _dist_init(rank, world_size, port) + torch.manual_seed(0) + # batch=32: MXFP8 wgrad GEMM (K=batch) requires K divisible by MXFP8_BLOCK_SCALING_SIZE=32 + batch, in_f, out_f = 32, 64, 128 # out_f % (16*world_size)==0 → no padding + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + ) + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + # Forward under MXFP8 recipe – triggers setup() and MXFP8 quantization + recipe = MXFP8BlockScaling() + with fp8_autocast(enabled=True, fp8_recipe=recipe): + out = layer(inp, is_first_microbatch=True) + + # After the first forward pass setup() must have created a quantized shard + w = layer.weight + assert w.quantized is not None, "MXFP8 quantized shard must be set after setup()" + assert isinstance(w.quantized, QuantizedTensor), \ + f"weight.quantized should be QuantizedTensor, got {type(w.quantized)}" + + assert out.shape == (batch, out_f), f"unexpected output shape {out.shape}" + assert torch.isfinite(out).all(), "MXFP8 ETP output has non-finite values" + + # Backward should complete without error + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + out.sum().backward() + assert inp.grad is not None + assert inp.grad.shape == inp.shape + + # Second microbatch reuses cached quantized weight (skip_weight_cast path) + with fp8_autocast(enabled=True, fp8_recipe=recipe): + out2 = layer(inp.detach(), is_first_microbatch=False) + assert torch.isfinite(out2).all(), "MXFP8 ETP second-microbatch output has non-finite values" + + dist.destroy_process_group() + + +def _worker_mxfp8_linear_unaligned(rank, world_size, port): + """Verify MXFP8 ETP when out_features is not aligned to 16*world_size (padding path). + + MXFP8 requires tensor dims divisible by 32, so shard_size (= M_padded / world_size) + must be a multiple of 32. With world_size=4 this requires M_padded % 128 == 0. + out_f=120 gives M_padded=128, shard_size=32 (32 % 32 == 0). The last rank has + 24 real rows zero-padded to 32. After all-gather, _strip_padding removes the padded + rows before the GEMM, so the output has the original out_f columns. + """ + from transformer_engine.common.recipe import MXFP8BlockScaling + _dist_init(rank, world_size, port) + torch.manual_seed(0) + # out_f=120: M_padded=128, shard_size=32, last rank has 24 rows padded to 32. + # 120 is divisible by 8 (GEMM constraint), not by 64 (ETP alignment → padding needed). + out_f = 120 + in_f = 64 + batch = 32 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + ) + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + with fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()): + out = layer(inp, is_first_microbatch=True) + + # After _strip_padding removes the padded rows, output has out_f (not padded) cols. + assert out.shape == (batch, out_f), f"unexpected output shape {out.shape}" + assert torch.isfinite(out).all(), "MXFP8 ETP (unaligned) output has non-finite values" + dist.destroy_process_group() + + +def _requires_mxfp8(): + available, reason = is_mxfp8_available(return_reason=True) + if not available: + pytest.skip(f"MXFP8 not available: {reason}") + + +class TestMXFP8LinearETP: + def test_forward_backward(self): + _requires_mxfp8() + _requires_multi_gpu(4) + _run_distributed(_worker_mxfp8_linear, 4) + + def test_forward_unaligned_padding(self): + _requires_mxfp8() + _requires_multi_gpu(4) + _run_distributed(_worker_mxfp8_linear_unaligned, 4) + + +# --------------------------------------------------------------------------- +# 14. ETPConfig / update_config +# --------------------------------------------------------------------------- + +class TestETPConfig: + + def test_update_pad_for_alignment(self): + original = etp_module.ETP_CONFIG.pad_for_alignment + try: + etp_module.update_config(pad_for_alignment=8) + assert etp_module.ETP_CONFIG.pad_for_alignment == 8 + finally: + etp_module.update_config(pad_for_alignment=original) + + def test_update_weight_prefetch(self): + original = etp_module.ETP_CONFIG.weight_prefetch + try: + etp_module.update_config(weight_prefetch=False) + assert etp_module.ETP_CONFIG.weight_prefetch is False + finally: + etp_module.update_config(weight_prefetch=original) + + def test_invalid_key_raises(self): + with pytest.raises(ValueError, match="Unknown ETP config option"): + etp_module.update_config(nonexistent_key=123) + + +# --------------------------------------------------------------------------- +# 15. ETPShardedParam properties – shape computations and padding +# --------------------------------------------------------------------------- + +class TestETPShardedParamProperties: + + class _FakeGroup: + def __init__(self, size=4, rank=0): + self._size = size + self._rank = rank + def size(self): return self._size + def rank(self): return self._rank + + def _make_param(self, shape, pad_length=0, group_size=4, group_rank=0, + is_padded_last_rank=False): + p = ETPShardedParam(torch.zeros(*shape)) + p.group = self._FakeGroup(size=group_size, rank=group_rank) + p.pad_length = pad_length + p.is_padded_last_rank = is_padded_last_rank + p.expert_idx = None + return p + + # --- _unsharded_shape_padded --- + + def test_unsharded_shape_padded_no_padding(self): + # shape=(8, 4), group_size=4 → 8*4=32 rows, no padding + p = self._make_param((8, 4), pad_length=0, group_size=4, group_rank=2) + assert p._unsharded_shape_padded == (32, 4) + + def test_unsharded_shape_padded_last_rank_with_padding(self): + # shard has 15 real rows, pad_length=1, last rank → (15+1)*4=64 + p = self._make_param((15, 32), pad_length=1, group_size=4, group_rank=3, + is_padded_last_rank=True) + assert p._unsharded_shape_padded == (64, 32) + + def test_unsharded_shape_padded_non_last_rank_with_padding(self): + # Non-last rank: pad_length metadata set but shape just multiplied + p = self._make_param((16, 32), pad_length=1, group_size=4, group_rank=0, + is_padded_last_rank=False) + assert p._unsharded_shape_padded == (64, 32) + + # --- _unsharded_shape --- + + def test_unsharded_shape_no_padding(self): + p = self._make_param((8, 4), pad_length=0, group_size=4, group_rank=0) + assert p._unsharded_shape == (32, 4) + + def test_unsharded_shape_strips_padding(self): + # padded = 64, strip 1 → 63 + p = self._make_param((15, 32), pad_length=1, group_size=4, group_rank=3, + is_padded_last_rank=True) + assert p._unsharded_shape == (63, 32) + + # --- get_padded_shard --- + + def test_get_padded_shard_identity_when_no_padding(self): + p = self._make_param((6, 4), pad_length=0) + result = p.get_padded_shard() + assert result is p # identity – no copy needed + + def test_get_padded_shard_identity_non_last_rank(self): + # pad_length > 0 but not the padded last rank → no padding added + p = self._make_param((16, 4), pad_length=1, group_size=4, group_rank=0, + is_padded_last_rank=False) + result = p.get_padded_shard() + assert result is p + + def test_get_padded_shard_appends_zero_rows(self): + p = self._make_param((6, 4), pad_length=2, group_size=4, group_rank=3, + is_padded_last_rank=True) + padded = p.get_padded_shard() + assert padded.shape == (8, 4), f"Expected (8,4), got {padded.shape}" + assert torch.all(padded[6:] == 0), "Padding rows must be zero" + + # --- _strip_padding --- + + def test_strip_padding_identity_no_padding(self): + p = self._make_param((8, 4), pad_length=0) + t = torch.randn(32, 4) + assert p._strip_padding(t) is t + + def test_strip_padding_plain_tensor(self): + # Gathered weight [32, 4] with pad_length=1 → strip 1 row → [31, 4] + p = self._make_param((7, 4), pad_length=1, group_size=4, group_rank=0) + t = torch.randn(32, 4) + result = p._strip_padding(t) + assert result.shape == (31, 4) + assert torch.equal(result, t[:-1]) + + def test_strip_padding_multi_row(self): + # pad_length=4 strips 4 rows + p = self._make_param((12, 8), pad_length=4, group_size=4, group_rank=0) + t = torch.ones(64, 8) + result = p._strip_padding(t) + assert result.shape == (60, 8) + + +# --------------------------------------------------------------------------- +# 16. _get_cache_key – expert vs non-expert, fwd vs bwd +# --------------------------------------------------------------------------- + +class TestETPCacheKey: + + class _FakeGroup: + def size(self): return 4 + def rank(self): return 0 + + def _param(self, shape=(16, 32), expert_idx=None): + p = ETPShardedParam(torch.zeros(*shape)) + p.group = self._FakeGroup() + p.expert_idx = expert_idx + p.pad_length = 0 + p.is_padded_last_rank = False + return p + + def test_non_expert_key_same_for_fwd_bwd(self): + """Non-routed params produce the same cache key for fwd and bwd.""" + p = self._param(expert_idx=None) + assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) == \ + p._get_cache_key(torch.bfloat16, fwd=False, reduce_scatter=False) + + def test_expert_key_differs_fwd_bwd(self): + """For quantized (non-torch.dtype) recipes, expert fwd vs bwd keys differ.""" + p = self._param(expert_idx=0) + # _get_cache_key differentiates fwd/bwd only for non-torch.dtype objects + # (e.g. quantized recipe dtype descriptors). Use a mock to trigger that path. + mock_dtype = "fp8" + assert p._get_cache_key(mock_dtype, fwd=True, reduce_scatter=False) != \ + p._get_cache_key(mock_dtype, fwd=False, reduce_scatter=False) + + def test_different_expert_idx_different_keys(self): + """Two experts with same shape but different indices get distinct keys.""" + p0 = self._param(expert_idx=0) + p1 = self._param(expert_idx=1) + assert p0._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) != \ + p1._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) + + def test_same_expert_idx_same_key(self): + """Same-shaped experts with the same idx share a cache key (cross-layer buffer reuse).""" + p_l0 = self._param(expert_idx=0) + p_l1 = self._param(expert_idx=0) + assert p_l0._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) == \ + p_l1._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) + + def test_different_dtypes_different_keys(self): + p = self._param() + assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) != \ + p._get_cache_key(torch.float32, fwd=True, reduce_scatter=False) + + def test_rs_key_differs_from_ag_key(self): + """reduce_scatter=True key must differ from reduce_scatter=False key.""" + p = self._param() + assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) != \ + p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=True) + + +# --------------------------------------------------------------------------- +# 17. ETPWeightCache.take() deferred vs get() immediate pool return +# --------------------------------------------------------------------------- + +class TestETPCacheRelease: + """Tests for ETPWeightCache reserve/get/release semantics.""" + + class _FakeGroup: + def size(self): return 2 + def rank(self): return 0 + + def _param(self, shape=(8, 4)): + p = ETPShardedParam(torch.zeros(*shape)) + p.group = self._FakeGroup() + p.expert_idx = None + p.pad_length = 0 + p.is_padded_last_rank = False + p._quantizer = None + return p + + def test_release_returns_buffer_to_pool(self): + """release() puts the buffer back so the next reserve+get reuses it.""" + cache = ETPWeightCache() + p = self._param() + t1 = cache.reserve(p, torch.bfloat16, fwd=True) + buf1 = cache.get(t1) + cache.release(t1) + # New ticket should pop buf1 from pool + t2 = cache.reserve(p, torch.bfloat16, fwd=True) + buf2 = cache.get(t2) + assert buf2 is buf1, "Buffer should be reused after release()" + cache.release(t2) + + def test_without_release_pool_stays_empty(self): + """Without release(), subsequent reserves allocate fresh buffers.""" + cache = ETPWeightCache() + p = self._param() + t1 = cache.reserve(p, torch.bfloat16, fwd=True) + buf1 = cache.get(t1) + # Do NOT release t1 — pool stays empty + t2 = cache.reserve(p, torch.bfloat16, fwd=True) + buf2 = cache.get(t2) + assert buf2 is not buf1, "Without release, a fresh buffer must be allocated" + + def test_get_same_ticket_returns_same_buf(self): + """get() is idempotent — calling it twice returns the same buffer.""" + cache = ETPWeightCache() + p = self._param() + t = cache.reserve(p, torch.bfloat16, fwd=True) + buf_a = cache.get(t) + buf_b = cache.get(t) + assert buf_a is buf_b + cache.release(t) + + def test_release_invalid_ticket_raises(self): + cache = ETPWeightCache() + with pytest.raises(KeyError): + cache.release(9999) + + +# --------------------------------------------------------------------------- +# 18. tag_etp_params_with_names – _debug_name population +# --------------------------------------------------------------------------- + +class TestTagETPParamsWithNames: + + def test_debug_name_populated_for_etp_param(self): + """ETPShardedParam._debug_name is set to the dotted parameter path.""" + class _FakeGroup: + def size(self): return 1 + def rank(self): return 0 + + model = nn.Linear(4, 8, bias=False) + w = ETPShardedParam(torch.randn(8, 4)) + w.group = _FakeGroup() + model._parameters['weight'] = w + + etp_module.tag_etp_params_with_names(model) + assert w._debug_name == 'weight', \ + f"Expected 'weight', got '{w._debug_name}'" + + def test_nested_module_debug_name(self): + """Nested module produces a dotted debug name.""" + class _FakeGroup: + def size(self): return 1 + def rank(self): return 0 + + outer = nn.Sequential(nn.Linear(4, 8, bias=False)) + w = ETPShardedParam(torch.randn(8, 4)) + w.group = _FakeGroup() + outer._modules['0']._parameters['weight'] = w + + etp_module.tag_etp_params_with_names(outer) + assert w._debug_name == '0.weight', \ + f"Expected '0.weight', got '{w._debug_name}'" + + def test_non_etp_params_are_skipped(self): + """Plain nn.Parameter instances are silently ignored.""" + model = nn.Linear(4, 8) + etp_module.tag_etp_params_with_names(model) # must not raise + + +# --------------------------------------------------------------------------- +# 19. _finalize_wgrad – strip padding, fuse accumulation, hook invocation +# --------------------------------------------------------------------------- + +class TestFinalizeWgrad: + """Tests for ETPShardedParam._finalize_wgrad(param, wgrad_rs). + + Current behaviour: always accumulates wgrad_rs into param.main_grad, + strips padding when is_padded_last_rank=True, resets rs_state to NONE, + and returns a dummy-zero grad tensor with the same shape as main_grad. + """ + + class _FakeGroup: + def size(self): return 2 + def rank(self): return 0 + + def _param(self, shape=(8, 4), pad_length=0, is_padded_last_rank=False, device="cuda"): + p = ETPShardedParam(torch.zeros(*shape, device=device)) + p.group = self._FakeGroup() + p.pad_length = pad_length + p.is_padded_last_rank = is_padded_last_rank + p.main_grad = torch.zeros(*shape, device=device) + return p + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_accumulates_into_main_grad(self): + p = self._param() + wgrad = torch.ones(8, 4, device="cuda") + ETPShardedParam._finalize_wgrad(p, wgrad) + assert torch.all(p.main_grad == 1), "main_grad should equal wgrad after accumulation" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_returns_dummy_zero_grad(self): + p = self._param() + wgrad = torch.ones(8, 4, device="cuda") + result = ETPShardedParam._finalize_wgrad(p, wgrad) + assert result.shape == p.shape, "dummy grad shape must match shard shape" + assert torch.all(result == 0), "dummy grad must be zeroes" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_strips_padding_for_padded_rank(self): + # Shard has 7 real rows, pad_length=1, is_padded_last_rank=True. + # RS output has 8 rows (7 real + 1 pad); strip to 7. + p = self._param(shape=(7, 4), pad_length=1, is_padded_last_rank=True) + # main_grad must match the real shard shape (7 rows) + p.main_grad = torch.zeros(7, 4, device="cuda") + wgrad = torch.ones(8, 4, device="cuda") + ETPShardedParam._finalize_wgrad(p, wgrad) + assert torch.all(p.main_grad == 1), "main_grad (7 rows) should be fully updated" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_rs_state_reset_to_none(self): + p = self._param() + p._set_rs_state(ETPWeightState.DATA_READY_SYNC) + wgrad = torch.ones(8, 4, device="cuda") + ETPShardedParam._finalize_wgrad(p, wgrad) + assert p.rs_state == ETPWeightState.NONE, "rs_state should be reset to NONE" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_grad_added_to_main_grad_flag(self): + p = self._param() + p.grad_added_to_main_grad = False + wgrad = torch.ones(8, 4, device="cuda") + ETPShardedParam._finalize_wgrad(p, wgrad) + assert p.grad_added_to_main_grad is True + + +# --------------------------------------------------------------------------- +# 20. wrap_module_params_etp is a no-op when etp_group.size() == 1 +# --------------------------------------------------------------------------- + +class TestETPGroupSizeOne: + + class _SingletonGroup: + def size(self): return 1 + def rank(self): return 0 + + def test_no_sharding_when_etp_size_one(self): + """wrap_module_params_etp must be a no-op for a singleton ETP group.""" + mod = nn.Linear(32, 64, bias=False) + original_weight = mod.weight + wrap_module_params_etp(mod, ['weight'], self._SingletonGroup()) + assert mod.weight is original_weight, \ + "etp_group.size()==1 should leave parameters unchanged" + assert not isinstance(mod.weight, ETPShardedParam) + + +# --------------------------------------------------------------------------- +# 21. weight_prefetch=False: forward still produces correct output +# --------------------------------------------------------------------------- + +def _worker_prefetch_disabled(rank, world_size, port): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + etp_module.update_config(weight_prefetch=False) + try: + l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", etp_group=etp_group) + l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", etp_group=etp_group) + + inp = torch.randn(4, in_f, dtype=dtype, device="cuda") + dist.broadcast(inp, src=0) + + # Single forward pass: builds chain and verifies output is correct + out = l0(inp, is_first_microbatch=True) + l1(inp, is_first_microbatch=True) + + # Chain should still be wired even with prefetch disabled + assert l0.weight.next_w is l1.weight + assert torch.isfinite(out).all(), "Non-finite output with prefetch disabled" + finally: + etp_module.update_config(weight_prefetch=True) + dist.destroy_process_group() + + +class TestETPPrefetchDisabled: + def test_forward_works_without_prefetch(self): + _requires_multi_gpu(4) + _run_distributed(_worker_prefetch_disabled, 4) + + +# --------------------------------------------------------------------------- +# 22. fuse_wgrad_accumulation=True: wgrad is accumulated into main_grad +# --------------------------------------------------------------------------- + +def _worker_fuse_wgrad(rank, world_size, port): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 128 # out_f % (16*world_size)==0, no padding + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + fuse_wgrad_accumulation=True, + ) + + # Allocate main_grad on the local shard shape + w = layer.weight + w.main_grad = torch.zeros(w.shape, dtype=dtype, device="cuda") + + inp = torch.randn(8, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + layer(inp, is_first_microbatch=True).sum().backward() + + # With fused accumulation, wgrad was added into main_grad + assert torch.any(w.main_grad != 0), \ + "main_grad should have been updated by fused wgrad accumulation" + dist.destroy_process_group() + + +class TestFuseWgradAccumulation: + def test_wgrad_accumulated_into_main_grad(self): + _requires_multi_gpu(4) + _run_distributed(_worker_fuse_wgrad, 4) + + +# --------------------------------------------------------------------------- +# 23. _grad_accum_hook is called after reduce-scatter +# --------------------------------------------------------------------------- + +def _worker_main_grad_updated_after_bwd(rank, world_size, port): + """After backward, _finalize_wgrad must have accumulated wgrad into main_grad.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + etp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", etp_group=etp_group, + ) + + # _finalize_wgrad always accumulates into main_grad; allocate before backward. + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + + inp = torch.randn(8, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + layer(inp, is_first_microbatch=True).sum().backward() + + assert torch.any(layer.weight.main_grad != 0), \ + "main_grad should have been updated by _finalize_wgrad after reduce-scatter" + dist.destroy_process_group() + + +class TestETPGradAccumHook: + def test_main_grad_updated_after_backward(self): + _requires_multi_gpu(4) + _run_distributed(_worker_main_grad_updated_after_bwd, 4) + + diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a105a0343f..032d635e61 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -163,6 +163,7 @@ list(APPEND transformer_engine_cuda_sources recipe/current_scaling.cu recipe/delayed_scaling.cu recipe/fp8_block_scaling.cu + recipe/multi_amax.cu comm_gemm_overlap/userbuffers/userbuffers.cu) list(APPEND transformer_engine_cuda_arch_specific_sources diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index cad27a2992..2244056823 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -99,6 +99,26 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig config, cudaStream_t stream); +/*! \brief Compute amax for a list of independent tensors in a single kernel launch. + * + * Unlike nvte_group_amax (which requires a single contiguous input split along dim 0), + * this API accepts arrays of independent input tensors, each with its own allocation. + * Designed for the ETP grouped-experts case where per-expert weights live in separate + * buffers. For each i in [0, num_tensors), computes amax(inputs[i]) and writes it to + * outputs[i]'s amax buffer. outputs[i] must be an FP8 per-tensor scaling or NVFP4 1D + * scaling tensor. All inputs must share the same dtype. If the list exceeds the + * per-launch batch capacity, it is internally chunked. + * + * \param[in] inputs Array of input tensors (unquantized). Size num_tensors. + * \param[in,out] outputs Array of output tensors. Only the amax is updated. + * Size num_tensors. + * \param[in] num_tensors Number of tensors. + * \param[in] config Quantization configuration (for noop_tensor). May be NULL. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_compute_amax(const NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, + const NVTEQuantizationConfig config, cudaStream_t stream); + /*! \brief Update an FP8 tensor's scale based on its amax. * * This is only supported for FP8 tensors with per-tensor scaling. diff --git a/transformer_engine/common/recipe/multi_amax.cu b/transformer_engine/common/recipe/multi_amax.cu new file mode 100644 index 0000000000..5420dde587 --- /dev/null +++ b/transformer_engine/common/recipe/multi_amax.cu @@ -0,0 +1,274 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../util/vectorized_pointwise.h" +#include "recipe_common.cuh" + +namespace transformer_engine { +namespace { + +constexpr int multi_amax_kernel_threads = 512; +// Per-launch capacity. kMaxTensorsPerBatch * ~40 bytes per slot keeps the args +// struct within the 4KB kernel parameter limit with comfortable headroom. +constexpr int kMaxTensorsPerBatch = 64; + +struct MultiAmaxArgs { + const void *input_list[kMaxTensorsPerBatch]; + void *output_rowwise_amax_list[kMaxTensorsPerBatch]; + void *output_columnwise_amax_list[kMaxTensorsPerBatch]; + size_t input_numel[kMaxTensorsPerBatch]; + size_t num_aligned_elements[kMaxTensorsPerBatch]; + int num_tensors; +}; + +// Zero out every output amax slot (rowwise + columnwise, deduped) in a single launch. +// Respects the noop_ptr contract shared with the single-tensor amax path. +__launch_bounds__(multi_amax_kernel_threads) __global__ + void MultiZeroAmaxKernel(MultiAmaxArgs args, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < args.num_tensors; tid += stride) { + float *rw = static_cast(args.output_rowwise_amax_list[tid]); + float *cw = static_cast(args.output_columnwise_amax_list[tid]); + if (rw != nullptr) { + *rw = 0.0f; + } + if (cw != nullptr && cw != rw) { + *cw = 0.0f; + } + } +} + +// Per-tensor amax with one block-strip per tensor. blockIdx.y selects the +// tensor; blockIdx.x is the work chunk within that tensor. Each block +// vector-loads the tensor, reduces across threads, and atomicMaxFloats the +// result into BOTH output amax slots (rowwise + columnwise, deduped). This +// subsumes the per-expert D2D copy that the single-tensor path does after the +// amax kernel. +template +__launch_bounds__(multi_amax_kernel_threads) __global__ + void MultiAmaxKernel(MultiAmaxArgs args, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + + const int t_idx = blockIdx.y; + if (t_idx >= args.num_tensors) { + return; + } + + const InputType *input = static_cast(args.input_list[t_idx]); + const size_t N = args.input_numel[t_idx]; + if (N == 0) { + return; + } + const size_t M = args.num_aligned_elements[t_idx]; + + VectorizedLoader loader(input, N); + InputType max = InputType{0.f}; + const int warp_id = threadIdx.x / THREADS_PER_WARP; + + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; + tid += gridDim.x * blockDim.x) { + loader.load(tid, N); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const InputType val = static_cast(loader.separate()[i]); + __builtin_assume(max >= InputType{0.f}); + if constexpr (std::is_same_v) { +#if __CUDA_ARCH__ >= 800 + max = __hmax(__habs(val), max); +#else + max = static_cast<__nv_bfloat16>( + fmaxf(fabsf(static_cast(val)), static_cast(max))); +#endif + } else if constexpr (std::is_same_v) { + max = __hmax(__habs(val), max); + } else { + max = fmaxf(fabsf(val), max); + } + } + } + + // Reduce amax over block. + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + float *rw = static_cast(args.output_rowwise_amax_list[t_idx]); + float *cw = static_cast(args.output_columnwise_amax_list[t_idx]); + if (rw != nullptr) { + atomicMaxFloat(rw, static_cast(max)); + } + if (cw != nullptr && cw != rw) { + atomicMaxFloat(cw, static_cast(max)); + } + } +} + +template +void launch_multi_amax_batch(const MultiAmaxArgs &args, size_t max_numel, Alignment align, + const float *noop_ptr, cudaStream_t stream) { + // Zero all amax outputs in one launch. + { + constexpr int threads = multi_amax_kernel_threads; + const int num_blocks = std::max(1, DIVUP(args.num_tensors, threads)); + MultiZeroAmaxKernel<<>>(args, noop_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); + } + + if (max_numel == 0) { + return; + } + + // Grid: y = tensor index, x = work chunks within the largest tensor. Blocks + // that exceed a shorter tensor's aligned element count bail out via the + // bounds check inside the kernel. + constexpr int nvec = 32 / sizeof(InputType); + constexpr size_t threads = multi_amax_kernel_threads; + const size_t max_aligned = (max_numel + nvec - 1) / nvec; + size_t num_blocks_x = DIVUP(max_aligned, threads); + constexpr size_t max_blocks = 65535; + num_blocks_x = std::min(num_blocks_x, max_blocks); + num_blocks_x = std::max(num_blocks_x, 1); + dim3 grid(num_blocks_x, static_cast(args.num_tensors), 1); + + switch (align) { + case Alignment::SAME_ALIGNED: + MultiAmaxKernel + <<>>(args, noop_ptr); + break; + case Alignment::SAME_UNALIGNED: + MultiAmaxKernel + <<>>(args, noop_ptr); + break; + case Alignment::DIFFERENT: + // Heterogeneous alignment across tensors — fall back to nvec=1, aligned=true path + // which is safe for any pointer alignment. + MultiAmaxKernel<1, true, InputType> + <<>>(args, noop_ptr); + break; + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Fill one MultiAmaxArgs batch from a slice of the full input/output list. +// Returns (max_numel in this batch, worst-case alignment across the batch). +template +std::pair build_batch_args(const std::vector &inputs, + const std::vector &outputs, size_t start, + size_t count, MultiAmaxArgs &args) { + constexpr int nvec = 32 / sizeof(InputType); + size_t max_numel = 0; + // SAME_ALIGNED is the most optimistic; degrade to SAME_UNALIGNED if any + // tensor is merely same-layout but unaligned, to DIFFERENT if alignment + // varies across tensors. + Alignment batch_align = Alignment::SAME_ALIGNED; + for (size_t i = 0; i < count; ++i) { + const Tensor &inp = *inputs[start + i]; + Tensor &out = *outputs[start + i]; + const size_t N = inp.data.numel(); + void *rw_ptr = out.amax.dptr; + void *cw_ptr = out.columnwise_amax.dptr; + + args.input_list[i] = inp.data.dptr; + args.output_rowwise_amax_list[i] = rw_ptr; + args.output_columnwise_amax_list[i] = cw_ptr; + args.input_numel[i] = N; + args.num_aligned_elements[i] = get_num_aligned_elements(inp.data.dptr, N, nvec, + sizeof(InputType)); + max_numel = std::max(max_numel, N); + + // Fold this tensor's alignment into the batch decision. CheckAlignment on a + // single pointer yields SAME_ALIGNED or SAME_UNALIGNED; mixing the two across + // tensors means heterogeneous — switch to the DIFFERENT fall-back. + if (N > 0) { + Alignment a = CheckAlignment(N, nvec, static_cast(inp.data.dptr)); + if (batch_align == Alignment::SAME_ALIGNED && a == Alignment::SAME_UNALIGNED) { + batch_align = Alignment::SAME_UNALIGNED; + } else if (batch_align == Alignment::SAME_UNALIGNED && a == Alignment::SAME_ALIGNED) { + batch_align = Alignment::SAME_UNALIGNED; + } else if (a == Alignment::DIFFERENT) { + batch_align = Alignment::DIFFERENT; + } + } + } + args.num_tensors = static_cast(count); + return {max_numel, batch_align}; +} + +void multi_compute_amax_impl(const NVTETensor *inputs_, NVTETensor *outputs_, size_t num_tensors, + const NVTEQuantizationConfig config_, cudaStream_t stream) { + if (num_tensors == 0) { + return; + } + NVTE_CHECK(inputs_ != nullptr, "nvte_multi_compute_amax: inputs is NULL"); + NVTE_CHECK(outputs_ != nullptr, "nvte_multi_compute_amax: outputs is NULL"); + + // Convert, validate, collect into plain vectors. + std::vector inputs(num_tensors); + std::vector outputs(num_tensors); + DType input_dtype; + for (size_t i = 0; i < num_tensors; ++i) { + inputs[i] = convertNVTETensorCheck(inputs_[i]); + outputs[i] = convertNVTETensorCheck(outputs_[i]); + const auto &inp = *inputs[i]; + auto &out = *outputs[i]; + NVTE_CHECK(inp.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "nvte_multi_compute_amax: input[", i, + "] must be unquantized, got scaling_mode=", to_string(inp.scaling_mode)); + NVTE_CHECK(!is_fp8_dtype(inp.data.dtype), + "nvte_multi_compute_amax: input[", i, + "] must be unquantized, got dtype=", to_string(inp.data.dtype)); + if (i == 0) { + input_dtype = inp.data.dtype; + } else { + NVTE_CHECK(inp.data.dtype == input_dtype, + "nvte_multi_compute_amax: all inputs must share dtype; input[0]=", + to_string(input_dtype), ", input[", i, "]=", to_string(inp.data.dtype)); + } + NVTE_CHECK(out.scaling_mode == NVTE_DELAYED_TENSOR_SCALING || + out.scaling_mode == NVTE_NVFP4_1D_SCALING, + "nvte_multi_compute_amax: output[", i, "] must be FP8 per-tensor or NVFP4 1D"); + NVTE_CHECK(out.amax.dptr != nullptr || out.columnwise_amax.dptr != nullptr, + "nvte_multi_compute_amax: output[", i, "] has no amax buffer"); + } + + const float *noop_ptr = nullptr; + if (config_ != nullptr) { + const QuantizationConfig *config_cpp = reinterpret_cast(config_); + const NVTETensor noop = config_cpp->noop_tensor; + noop_ptr = reinterpret_cast( + (noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr)); + } + + // Chunk across kMaxTensorsPerBatch launches (single launch in the common 8-expert case). + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input_dtype, IType, { + for (size_t start = 0; start < num_tensors; start += kMaxTensorsPerBatch) { + const size_t count = std::min(kMaxTensorsPerBatch, num_tensors - start); + MultiAmaxArgs args = {}; + auto [max_numel, batch_align] = build_batch_args(inputs, outputs, start, count, args); + launch_multi_amax_batch(args, max_numel, batch_align, noop_ptr, stream); + } + }); // NOLINT(*) +} + +} // anonymous namespace +} // namespace transformer_engine + +void nvte_multi_compute_amax(const NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, + const NVTEQuantizationConfig config, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_compute_amax); + transformer_engine::multi_compute_amax_impl(inputs, outputs, num_tensors, config, stream); +} diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 6aab9938b3..e5ead50d09 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -365,11 +365,30 @@ class NVFP4Quantizer : public Quantizer { */ void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); + /*! @brief Compute (and D2D fill) local amax only — no cast, no allreduce. + * + * Writes the local amax into out's rowwise and/or columnwise amax + * buffers. Callers are expected to perform a coalesced allreduce + * across the amax reduction group afterwards, then invoke + * quantize_cast_only to finish the cast with the reduced amax. + */ + void compute_amax_only(const TensorWrapper& input, TensorWrapper& out); + + /*! @brief Cast to NVFP4 assuming amax already reduced externally. + * + * Skips both local amax compute and the internal amax allreduce. + * Callers must guarantee out's amax buffers already hold the reduced + * amax (e.g. via compute_amax_only + allreduce_coalesced). + */ + void quantize_cast_only(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt); + std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; private: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag, bool compute_amax); + const std::optional& noop_flag, bool compute_amax, + bool skip_amax_reduction = false); }; std::unique_ptr convert_quantizer(py::handle quantizer); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e4d4e5094c..c9b5674426 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -285,6 +285,21 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, std::optional noop_flag); +// NVFP4-only split-phase quantize: compute amax, coalesce allreduce externally, then cast. +py::object compute_amax_nvfp4(const at::Tensor &tensor, py::handle quantizer, + const py::object &output); +py::object quantize_cast_only_nvfp4(const at::Tensor &tensor, py::handle quantizer, + const py::object &output, + std::optional noop_flag); + +// NVFP4-only multi-tensor amax: fuses N per-expert (zero_amax + amax + D2D replicate) +// chains into a single pair of kernel launches (one multi-zero + one multi-amax) that +// writes amax into every output's rowwise AND columnwise buffers. Outputs must be +// pre-allocated; amax is written in place, no return. +void compute_multi_amax_nvfp4(const std::vector &tensor_list, + std::vector quantizer_list, + const std::vector &output_list); + py::object dequantize(const py::handle &input, DType otype); py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f8f793f036..e2602ed133 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -80,6 +80,148 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob return output_py; } +/*! @brief NVFP4-only: compute local amax into `output`'s amax buffers, no cast, no allreduce. + * + * Pair with an external coalesced allreduce of the returned amax tensors, + * then call `quantize_cast_only_nvfp4` to finish the cast. + */ +py::object compute_amax_nvfp4(const at::Tensor &tensor, py::handle quantizer, + const py::object &output) { + NVTE_CHECK(detail::IsNVFP4Quantizers(quantizer.ptr()), + "compute_amax_nvfp4 requires an NVFP4Quantizer"); + auto quantizer_cpp = convert_quantizer(quantizer); + auto *nvfp4_quantizer = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer != nullptr, "Failed to cast quantizer to NVFP4Quantizer"); + + auto input_contiguous = tensor.contiguous(); + auto input_cpp = makeTransformerEngineTensor(input_contiguous); + + TensorWrapper output_cpp; + py::object output_py; + if (output.is_none()) { + const auto shape = get_tensor_shape(input_cpp); + const auto fake_dtype = input_cpp.dtype(); + std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); + } else { + std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output); + } + + nvfp4_quantizer->compute_amax_only(input_cpp, output_cpp); + return output_py; +} + +/*! @brief NVFP4-only: cast to FP4 using pre-reduced amax in `output`'s amax buffers. + * + * Skips both local amax compute and the internal allreduce. Caller must have + * already populated `output`'s amax via compute_amax_nvfp4 + coalesced allreduce. + */ +py::object quantize_cast_only_nvfp4(const at::Tensor &tensor, py::handle quantizer, + const py::object &output, + std::optional noop_flag) { + NVTE_CHECK(detail::IsNVFP4Quantizers(quantizer.ptr()), + "quantize_cast_only_nvfp4 requires an NVFP4Quantizer"); + auto quantizer_cpp = convert_quantizer(quantizer); + auto *nvfp4_quantizer = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer != nullptr, "Failed to cast quantizer to NVFP4Quantizer"); + + auto input_contiguous = tensor.contiguous(); + auto input_cpp = makeTransformerEngineTensor(input_contiguous); + + TensorWrapper output_cpp; + py::object output_py; + if (output.is_none()) { + const auto shape = get_tensor_shape(input_cpp); + const auto fake_dtype = input_cpp.dtype(); + std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); + } else { + std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output); + } + + std::optional noop_flag_cpp; + if (noop_flag.has_value()) { + noop_flag_cpp = makeTransformerEngineTensor(*noop_flag); + } + + nvfp4_quantizer->quantize_cast_only(input_cpp, output_cpp, noop_flag_cpp); + return output_py; +} + +/*! @brief NVFP4-only: compute amax for N input tensors in a single launch. + * + * Each output's rowwise AND columnwise amax buffers are populated directly by the + * kernel (atomicMaxFloat), fusing the per-expert zero_amax + amax_kernel + D2D + * replicate chain into two multi-tensor launches. Caller pairs this with an + * external coalesced allreduce and then N calls to quantize_cast_only_nvfp4. + * + * Amax is written into the outputs passed in via output_list; no return value is + * needed — caller already holds references to those objects. + */ +void compute_multi_amax_nvfp4(const std::vector &tensor_list, + std::vector quantizer_list, + const std::vector &output_list) { + const size_t num_tensors = tensor_list.size(); + NVTE_CHECK(num_tensors > 0, "compute_multi_amax_nvfp4 requires at least one tensor"); + NVTE_CHECK(quantizer_list.size() == num_tensors, + "compute_multi_amax_nvfp4: quantizer_list size mismatch"); + NVTE_CHECK(output_list.size() == num_tensors, + "compute_multi_amax_nvfp4: output_list size mismatch"); + + // Locals held for the duration of this call (destroyed at function return). + // TensorWrappers only hold NVTETensor handles (opaque indexes into a global pool + // released by ~TensorWrapper); they do NOT reference quantizer_cpp or py::object, + // so we do not need to preserve quantizer unique_ptrs past this scope. + std::vector input_contiguous; + input_contiguous.reserve(num_tensors); + std::vector input_wrappers; + input_wrappers.reserve(num_tensors); + std::vector output_wrappers; + output_wrappers.reserve(num_tensors); + + std::vector inputs_nvte; + std::vector outputs_nvte; + inputs_nvte.reserve(num_tensors); + outputs_nvte.reserve(num_tensors); + + for (size_t i = 0; i < num_tensors; ++i) { + NVTE_CHECK(detail::IsNVFP4Quantizers(quantizer_list[i].ptr()), + "compute_multi_amax_nvfp4: quantizer[", i, "] is not an NVFP4Quantizer"); + auto quantizer_cpp = convert_quantizer(quantizer_list[i]); + auto *nvfp4_quantizer = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer != nullptr && !nvfp4_quantizer->with_rht, + "compute_multi_amax_nvfp4 requires NVFP4Quantizer with with_rht=false (idx=", i, + ")"); + + input_contiguous.emplace_back(tensor_list[i].contiguous()); + input_wrappers.emplace_back(makeTransformerEngineTensor(input_contiguous.back())); + + TensorWrapper out_cpp; + py::object out_py; + NVTE_CHECK(!output_list[i].is_none(), + "compute_multi_amax_nvfp4: output_list[", i, "] is None; caller must pre-allocate"); + std::tie(out_cpp, out_py) = quantizer_cpp->convert_and_update_tensor(output_list[i]); + + NVTE_CHECK(out_cpp.get_amax().data_ptr != nullptr || + out_cpp.get_columnwise_amax().data_ptr != nullptr, + "compute_multi_amax_nvfp4: output[", i, "] has no amax buffer"); + + output_wrappers.emplace_back(std::move(out_cpp)); + // quantizer_cpp and out_py are released here at end-of-iteration. + + if (input_wrappers.back().numel() == 0) continue; + inputs_nvte.push_back(input_wrappers.back().data()); + outputs_nvte.push_back(output_wrappers.back().data()); + } + + if (inputs_nvte.empty()) return; + + QuantizationConfigWrapper quant_config; + auto stream = at::cuda::getCurrentCUDAStream(); + NVTE_SCOPED_GIL_RELEASE({ + nvte_multi_compute_amax(inputs_nvte.data(), outputs_nvte.data(), inputs_nvte.size(), + quant_config, stream); + }); +} + namespace { // helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 8302a13010..2a9281bc78 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -137,6 +137,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("output") = py::none(), py::arg("noop") = py::none()); + m.def("compute_amax_nvfp4", transformer_engine::pytorch::compute_amax_nvfp4, + "NVFP4: compute local amax into output's amax buffers; no cast, no allreduce", + py::arg("tensor"), py::arg("quantizer"), py::arg("output") = py::none()); + m.def("quantize_cast_only_nvfp4", transformer_engine::pytorch::quantize_cast_only_nvfp4, + "NVFP4: cast using pre-reduced amax in output's amax buffers; skips amax compute and allreduce", + py::arg("tensor"), py::arg("quantizer"), py::arg("output") = py::none(), + py::arg("noop") = py::none()); + m.def("compute_multi_amax_nvfp4", transformer_engine::pytorch::compute_multi_amax_nvfp4, + "NVFP4: fused multi-tensor amax compute (writes both rowwise+columnwise amax per output)", + py::arg("tensor_list"), py::arg("quantizer_list"), py::arg("output_list")); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"), diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0214f7ff71..ac8fa26bc3 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2121,7 +2121,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, - bool compute_amax) { + bool compute_amax, bool skip_amax_reduction) { // Nothing to be done if input is empty if (input.numel() == 0) { return; @@ -2225,7 +2225,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } // amax reduction - if (this->with_amax_reduction) { + if (this->with_amax_reduction && !skip_amax_reduction) { std::vector amax_tensors; // push amax tensors inside if they need to be reduced auto make_amax_tensor = [](void* data_ptr) { @@ -2378,6 +2378,54 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out this->quantize_impl(input, out, std::nullopt, false); } +void NVFP4Quantizer::compute_amax_only(const TensorWrapper& input, TensorWrapper& out) { + // Nothing to be done if input is empty + if (input.numel() == 0) { + return; + } + + // Only the non-RHT path is supported for the split-phase API today. + // RHT path's amax depends on the RHT-rotated view, which is produced + // alongside the cast; decoupling amax from cast is not meaningful there. + NVTE_CHECK(!this->with_rht, + "NVFP4Quantizer::compute_amax_only does not support with_rht=true"); + + auto stream = at::cuda::getCurrentCUDAStream(); + + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); + + // Mirror the compute-amax block of quantize_impl exactly. + auto rowwise_amax_ptr = out.get_amax().data_ptr; + auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; + void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + + out.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); + out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector{1}); + + // Replicate amax into whichever of rowwise/columnwise slots were requested. + if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } +} + +void NVFP4Quantizer::quantize_cast_only(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + // Amax is expected to already live in out's amax buffers (e.g. from + // compute_amax_only + an external coalesced allreduce). Skip both local + // amax compute and the internal allreduce. + this->quantize_impl(input, out, noop_flag, /*compute_amax=*/false, + /*skip_amax_reduction=*/true); +} + std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { size_t numel = 1; diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index f269e21b8c..cea07eb6d5 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -6,7 +6,7 @@ from __future__ import annotations from collections.abc import Iterable -from contextlib import contextmanager, AbstractContextManager, ContextDecorator +from contextlib import contextmanager, AbstractContextManager, ContextDecorator, nullcontext from functools import lru_cache from dataclasses import dataclass import math @@ -908,7 +908,7 @@ def fork(self, name: str = "model-parallel-rng"): def reduce_scatter_along_first_dim( - inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False + inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False, output: torch.Tensor = None ) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """Reduce-scatter the input tensor across model parallel group.""" world_size = get_distributed_world_size(tp_group) @@ -916,14 +916,15 @@ def reduce_scatter_along_first_dim( if world_size == 1: return inp, None - dim_size = list(inp.size()) - assert ( - dim_size[0] % world_size == 0 - ), "First dimension of the tensor should be divisible by tensor parallel size" + if output is None: + dim_size = list(inp.size()) + assert ( + dim_size[0] % world_size == 0 + ), "First dimension of the tensor should be divisible by tensor parallel size" - dim_size[0] = dim_size[0] // world_size + dim_size[0] = dim_size[0] // world_size - output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) + output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) handle = torch.distributed.reduce_scatter_tensor( output, inp.contiguous(), group=tp_group, async_op=async_op ) @@ -1252,12 +1253,16 @@ def _post_process_nvfp4_gather( handle.wait() handle = None - # Fix the interleaved transposed data from gathering along first dim. - out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) - out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) + # TODO + # # Fix the interleaved transposed data from gathering along first dim. + # out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) + # out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) + out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size)) + out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size)) - # Optionally pad the scaling inverse if needed. - out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) + # # Optionally pad the scaling inverse if needed. + # out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) + out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv)) @dataclass @@ -1271,17 +1276,20 @@ class _NVFP4AllGatherAsyncHandle: async_handle: torch.distributed.Work _synchronized: bool = False - def wait(self) -> None: - """Wait for the async operation to complete and post-process the tensor.""" - if self._synchronized: - return - self.async_handle.wait() + def post_process_nvfp4_gather(self) -> None: _post_process_nvfp4_gather( self.output, self.columnwise_data_interleaved, self.columnwise_scale_inv_interleaved, self.world_size, ) + + def wait(self) -> None: + """Wait for the async operation to complete and post-process the tensor.""" + if self._synchronized: + return + self.async_handle.wait() + self.post_process_nvfp4_gather() self._synchronized = True @@ -1292,6 +1300,8 @@ def _all_gather_nvfp4( async_op: bool = False, quantizer: NVFP4Quantizer, out_shape: Optional[list[int]] = None, + output_tensor = None, + grouped = False, ) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]: """All-gather NVFP4 tensor along first dimension.""" @@ -1348,6 +1358,12 @@ def _all_gather_nvfp4( out = quantizer(out) return out, None + # Construct NVFP4 output tensor + if output_tensor is not None: + out = output_tensor + else: + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + # Cast input tensor to NVFP4 with required data if not isinstance(inp, NVFP4TensorStorage): inp = quantizer(inp) @@ -1360,17 +1376,19 @@ def _all_gather_nvfp4( ) inp = quantizer(inp.dequantize()) - # Construct NVFP4 output tensor - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - - # Coalesce NCCL collectives for gathering data and scale inverses. - with torch.distributed._coalescing_manager( - group=process_group, - device=device, - async_ops=async_op, - ) as gather_coalescing_manager: + if not grouped: + # Coalesce NCCL collectives for gathering data and scale inverses. + gather_coalescing_manager = torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) + else: + gather_coalescing_manager = nullcontext() + with gather_coalescing_manager as coalesced_handle: # Gather NVFP4 data for row-wise usage + out_columnwise_data = None if quantizer.rowwise_usage: # Remove padding from NVFP4 scale-inverses @@ -1395,7 +1413,9 @@ def _all_gather_nvfp4( ) # Transfer amax to output. - out._amax_rowwise = inp._amax_rowwise + #TODO: jiemingz + # out._amax_rowwise = inp._amax_rowwise + out._amax_rowwise.copy_(inp._amax_rowwise) # Gather the transposed NVFP4 data along first dimension. Fix format later. if quantizer.columnwise_usage: @@ -1444,17 +1464,25 @@ def _all_gather_nvfp4( ) # Transfer amax to output. - out._amax_columnwise = inp._amax_columnwise + out._amax_columnwise.copy_(inp._amax_columnwise) - handle = gather_coalescing_manager if async_op else None + + handle = coalesced_handle if async_op else None # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. - if async_op and quantizer.columnwise_usage: - handle = _NVFP4AllGatherAsyncHandle( - out, out_columnwise_data, out_scale_inv, world_size, handle - ) - elif quantizer.columnwise_usage: - _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + if quantizer.columnwise_usage: + if async_op or grouped: + # Defer post-processing: either the async op hasn't completed yet, or an + # external coalescing manager owns the NCCL ops and hasn't flushed them. + inner_handle = handle if async_op else None + handle = _NVFP4AllGatherAsyncHandle( + out, out_columnwise_data, out_scale_inv, world_size, inner_handle + ) + else: + _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + else: + if handle is not None: + handle.output = out return out, handle @@ -1466,6 +1494,8 @@ def _all_gather_mxfp8( async_op: bool = False, quantizer: MXFP8Quantizer, out_shape: Optional[list[int]] = None, + output_tensor: torch.Tensor = None, + grouped: bool = False, ) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]: """All-gather MXFP8 tensor along first dimension.""" @@ -1528,15 +1558,22 @@ def _all_gather_mxfp8( inp = quantizer(inp.dequantize()) # Construct MXFP8 output tensor - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + if output_tensor is not None: + out = output_tensor + else: + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - # Coalesce NCCL collectives - with torch.distributed._coalescing_manager( - group=process_group, - device=device, - async_ops=async_op, - ) as coalescing_manager: + if not grouped: + # Coalesce NCCL collectives for gathering data and scale inverses. + gather_coalescing_manager = torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) + else: + gather_coalescing_manager = nullcontext() + with gather_coalescing_manager as coalesced_handle: # Gather MXFP8 data for row-wise usage if quantizer.rowwise_usage: @@ -1583,7 +1620,7 @@ def _all_gather_mxfp8( group=process_group, ) - handle = coalescing_manager if async_op else None + handle = coalesced_handle if async_op else None return out, handle @@ -1592,6 +1629,8 @@ def gather_along_first_dim( process_group: dist_group_type, async_op: bool = False, quantizer: Optional[Quantizer] = None, + output_tensor: torch.Tensor = None, + grouped: bool = False, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: """ All-gather tensors and concatenate along first dimension. @@ -1679,6 +1718,8 @@ def gather_along_first_dim( async_op=async_op, quantizer=quantizer, out_shape=out_shape, + output_tensor=output_tensor, + grouped=grouped, ) # NVFP4 case @@ -1690,6 +1731,8 @@ def gather_along_first_dim( async_op=async_op, quantizer=quantizer, out_shape=out_shape, + output_tensor=output_tensor, + grouped=grouped, ) # High-precision communication for quantized tensors @@ -1719,19 +1762,20 @@ def gather_along_first_dim( inp = inp.dequantize() # Communication for plain PyTorch tensors - out = torch.empty( - out_shape, - dtype=inp.dtype, - device=inp.device, - memory_format=torch.contiguous_format, - ) + if output_tensor is None: + output_tensor = torch.empty( + out_shape, + dtype=inp.dtype, + device=inp.device, + memory_format=torch.contiguous_format, + ) handle = torch.distributed.all_gather_into_tensor( - out, + output_tensor, inp.contiguous(), group=process_group, async_op=async_op, ) - return out, handle + return output_tensor, handle # Global cache to store symmetric memory tensors diff --git a/transformer_engine/pytorch/module/extended_tensor_parallelism.py b/transformer_engine/pytorch/module/extended_tensor_parallelism.py new file mode 100644 index 0000000000..84dbe05eeb --- /dev/null +++ b/transformer_engine/pytorch/module/extended_tensor_parallelism.py @@ -0,0 +1,1724 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from collections import defaultdict +from contextlib import nullcontext +from typing import Dict, List, Optional +from enum import Enum +from dataclasses import dataclass, field +import math +import re +import torch + +from ..distributed import ( + gather_along_first_dim, + reduce_scatter_along_first_dim, + _NVFP4AllGatherAsyncHandle +) +from ..quantized_tensor import QuantizedTensor +from ..tensor import NVFP4TensorStorage, MXFP8TensorStorage +from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..utils import nvtx_range_pop, nvtx_range_push, round_up_to_nearest_multiple +from ..constants import NVFP4_BLOCK_SCALING_SIZE, MXFP8_BLOCK_SCALING_SIZE +from .base import get_dummy_wgrad + +import transformer_engine_torch as tex + +DEBUG_TENSOR = None + + +class ETPChain(str, Enum): + """Prefetch chain identifier for an ETPShardedParam. + + GRAPHED — fwd/bwd captured by a CUDA graph (MLM _CudaGraphRunner). + UNGRAPHED — fwd/bwd runs eagerly; includes embedding/output_layer and + routed grouped experts always, plus router/shared_experts + when their scope tag is not in cuda_graph_scope. + + Chains never cross-link (prev_w/next_w stay within one chain). CG + disabled → single UNGRAPHED chain; full-iteration graph → single GRAPHED. + """ + GRAPHED = "ETP_graphed" + UNGRAPHED = "ETP_ungraphed" + + +# Module-level cuda_graph_scope, set by MLM at init via set_cuda_graph_scope(). +# None or empty → CG is disabled; every ETP param classifies as UNGRAPHED. +# Value is a set of scope tags; e.g. {"mamba","attn","moe_router"}. +_CUDA_GRAPH_SCOPE: Optional[set] = None +# Whether shared_experts are run with overlap (cannot be captured). When True, +# shared_experts stay UNGRAPHED regardless of moe_router scope inclusion, matching +# the transformer_layer.py guard that excludes them from the captured submodules. +_MOE_SHARED_EXPERT_OVERLAP: bool = False + + +def set_cuda_graph_scope(scope, moe_shared_expert_overlap: bool = False): + """Record the active cuda_graph_scope for ETP chain classification. + + Called by MLM at init, BEFORE classify_etp_chains(). ``scope`` may be + None, an empty iterable (CG disabled), or an iterable of scope tags. + """ + global _CUDA_GRAPH_SCOPE, _MOE_SHARED_EXPERT_OVERLAP + _CUDA_GRAPH_SCOPE = set(scope) if scope else None + _MOE_SHARED_EXPERT_OVERLAP = bool(moe_shared_expert_overlap) + + +def _classify_param_chain(param_name: str) -> 'ETPChain': + """Classify an ETPShardedParam by name + active cuda_graph_scope. + + embedding / output_layer are always UNGRAPHED. Other kinds (mamba mixer, + self/cross_attention, shared_experts, routed experts) are GRAPHED iff + their scope tag is present in cuda_graph_scope; otherwise UNGRAPHED. + """ + n = param_name + + # Always ungraphed — embedding and output_layer live outside any CG runner. + if "embedding" in n or "output_layer" in n: + return ETPChain.UNGRAPHED + + scope = _CUDA_GRAPH_SCOPE + if not scope: + # CG disabled: every ETP param goes to the single UNGRAPHED chain. + return ETPChain.UNGRAPHED + + if ".mlp.shared_experts." in n: + if _MOE_SHARED_EXPERT_OVERLAP: + return ETPChain.UNGRAPHED + return ETPChain.GRAPHED if ("moe" in scope or "moe_router" in scope) else ETPChain.UNGRAPHED + + if ".mlp.experts." in n: + return ETPChain.GRAPHED if "moe" in scope else ETPChain.UNGRAPHED + + if ".self_attention." in n or ".cross_attention." in n: + return ETPChain.GRAPHED if "attn" in scope else ETPChain.UNGRAPHED + + if ".mixer." in n: + return ETPChain.GRAPHED if "mamba" in scope else ETPChain.UNGRAPHED + + return ETPChain.UNGRAPHED + + +def classify_etp_chains(model) -> None: + """Walk model.named_parameters() and set chain_id on every ETPShardedParam. + + Call once at init, AFTER set_cuda_graph_scope() and BEFORE the first fwd + of any graphed param. Raises if an already chain-initialized param would + be reclassified into a different chain (its prev/next links are already + wired into the wrong list). + """ + conflicts = [] + for name, param in model.named_parameters(): + if not isinstance(param, ETPShardedParam): + continue + target = _classify_param_chain(name).value + if param.prefetch_initialized and param.chain_id != target: + conflicts.append((name, param.chain_id, target)) + continue + param.chain_id = target + + # Bwd-prefetch opt-out: embedding.word_embeddings.weight does not need + # an AG in the bwd pass (its wgrad is a scatter-add on sharded rows + # and its input has no dgrad). Skipping its bwd AG saves one collective. + if "embedding" in name: + param._need_weight_prefetch_bwd = False + if conflicts: + raise RuntimeError( + "classify_etp_chains: the following params were already chain-initialized " + "with a different chain_id than the classifier would assign — this means " + "their chain links are already wired into the wrong list. Move classification " + "earlier in init. Conflicts: " + + ", ".join(f"{n}: {old!r}->{new!r}" for n, old, new in conflicts[:3]) + + ("..." if len(conflicts) > 3 else "") + ) + + +class ETPWeightState(Enum): + NONE = "NONE" # Sharded, no pending operation + ASYNC_WAIT = "ASYNC_WAIT" # Async all-gather in progress + DATA_READY = "DATA_READY" # Async all-gather complete, result in cache + DATA_READY_SYNC = "DATA_READY_SYNC" # Sync all-gather complete, result in cache + + +_STATE_TRANSITIONS = { + ETPWeightState.NONE: {ETPWeightState.ASYNC_WAIT, ETPWeightState.DATA_READY_SYNC}, + ETPWeightState.ASYNC_WAIT: {ETPWeightState.DATA_READY}, + ETPWeightState.DATA_READY: {ETPWeightState.NONE}, + ETPWeightState.DATA_READY_SYNC: {ETPWeightState.NONE}, +} + + +# Global ETP buffer cache (persists across clear(); never set to None after creation). +_ETP_CACHE = None +_ETP_PARAMS = [] + +# Global set of ETPShardedParam with in-flight async comms (AG or RS). +_inflight_comm_params: set = set() +_AG_STREAMS: Dict[str, torch.cuda.Stream] = {} +_RS_STREAMS: Dict[str, torch.cuda.Stream] = {} + +# Wgrad input buffer pool, keyed by (shape, dtype). UNGRAPHED-only: GRAPHED +# wgrad bufs need address stability for CG replay and are not pool-recycled. +_wgrad_buf_pool: Dict[tuple, list] = {} + + +def _wgrad_pool_get(shape: tuple, dtype: torch.dtype, device) -> torch.Tensor: + """Get a pool buffer or allocate fresh. Tagged so _wgrad_pool_put accepts + only pool-owned buffers — callers that don't use _wgrad_pool_get (e.g. + Megatron layers.py wgrad GEMM, aten F.embedding bwd) fall through to the + caching allocator on release.""" + key = (shape, dtype) + pool = _wgrad_buf_pool.get(key) + if pool: + buf = pool.pop() + else: + buf = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) + buf._from_etp_wgrad_pool = True + return buf + + +def _wgrad_pool_put(buf: torch.Tensor): + """Return a pool-owned buffer for reuse (no-op for untagged buffers; see + _wgrad_pool_get).""" + if not getattr(buf, '_from_etp_wgrad_pool', False): + return + key = (tuple(buf.shape), buf.dtype) + if key not in _wgrad_buf_pool: + _wgrad_buf_pool[key] = [] + _wgrad_buf_pool[key].append(buf) + + +def _stream_key(chain_id: str, group) -> tuple: + """Key for the per-(chain, group) AG/RS stream dicts. + + Two partitioning axes: + - chain_id: captured (GRAPHED) vs eager (UNGRAPHED) ops must not share + a stream (eager ops would contaminate capture/replay state). + - group: independent NCCL communicators (e.g. ETP vs EETP) get their + own user-level stream to avoid cross-group serialization. + """ + return (chain_id, id(group) if group is not None else 0) + + +def get_ag_stream(chain_id: str = ETPChain.GRAPHED.value, group=None) -> torch.cuda.Stream: + """Return the ETP all-gather stream for (chain_id, group). See _stream_key.""" + key = _stream_key(chain_id, group) + if key not in _AG_STREAMS: + _AG_STREAMS[key] = torch.cuda.Stream() + return _AG_STREAMS[key] + + +def get_rs_stream(chain_id: str = ETPChain.GRAPHED.value, group=None) -> torch.cuda.Stream: + """Return the ETP reduce-scatter stream for (chain_id, group). See _stream_key.""" + key = _stream_key(chain_id, group) + if key not in _RS_STREAMS: + _RS_STREAMS[key] = torch.cuda.Stream() + return _RS_STREAMS[key] + + +def get_all_ag_streams() -> list: + """All AG streams created so far, across chains and groups.""" + return list(_AG_STREAMS.values()) + + +def get_all_rs_streams() -> list: + """All RS streams created so far, across chains and groups.""" + return list(_RS_STREAMS.values()) + + +def get_ag_streams_for_chain(chain_id: str) -> list: + """AG streams for one chain (all groups that chain has touched).""" + return [s for k, s in _AG_STREAMS.items() if k[0] == chain_id] + + +def get_rs_streams_for_chain(chain_id: str) -> list: + """RS streams for one chain (all groups that chain has touched).""" + return [s for k, s in _RS_STREAMS.items() if k[0] == chain_id] + +# Cached once per process: whether the TE build exposes the split-phase APIs. +_COALESCED_AMAX_TE_APIS_AVAILABLE = ( + hasattr(tex, "compute_amax_nvfp4") and hasattr(tex, "quantize_cast_only_nvfp4") +) + +# Tier-2: multi-tensor amax kernel fuses N per-expert (zero_amax + amax + D2D) chains +# into two multi-tensor kernel launches. Independent of Tier-1 coalesced allreduce. +_MULTI_AMAX_TE_API_AVAILABLE = hasattr(tex, "compute_multi_amax_nvfp4") + + +def _coalesced_amax_static_eligible(weights): + """Check whether the coalesced-amax path is applicable (NVFP4 only). + + Caller already gates on ETP_CONFIG.coalesce_amax_allreduce (False for + non-NVFP4). Here we additionally verify TE API availability, batch size, + quantizer type (must have amax reduction), and the RHT flag.""" + dbg = ETP_CONFIG.debug_numerics > 0 + if not _COALESCED_AMAX_TE_APIS_AVAILABLE: + if dbg: + print_rank_0("[ETP_DEBUG] coalesced_amax_static: REJECTED (TE APIs unavailable)") + return False + if len(weights) <= 1: + return False + has_amax = [getattr(w._quantizer, "with_amax_reduction", False) for w in weights] + if not all(has_amax): + if dbg: + qtypes = [type(w._quantizer).__name__ for w in weights[:3]] + print_rank_0( + f"[ETP_DEBUG] coalesced_amax_static: REJECTED " + f"(with_amax_reduction={has_amax[:3]}{'...' if len(has_amax)>3 else ''}, " + f"quantizer_types={qtypes}{'...' if len(weights)>3 else ''}, " + f"n_weights={len(weights)})" + ) + return False + has_rht = any(getattr(w._quantizer, "with_rht", False) for w in weights) + if has_rht: + if dbg: + print_rank_0("[ETP_DEBUG] coalesced_amax_static: REJECTED (with_rht=True)") + return False + if dbg: + qtypes = [type(w._quantizer).__name__ for w in weights[:3]] + print_rank_0( + f"[ETP_DEBUG] coalesced_amax_static: *** ACCEPTED *** " + f"(n_weights={len(weights)}, quantizer_types={qtypes}{'...' if len(weights)>3 else ''})" + ) + return True + + +def _quantize_with_coalesced_amax(weights, skip_weight_cast, cast_noop_flag): + """Replace the per-weight (compute_amax + allreduce + cast) loop with: + compute_amax loop → one coalesced allreduce → cast loop.""" + group = weights[0]._quantizer.amax_reduction_group + + # Materialize padded shards once; on padded last-rank get_padded_shard() + # launches an F.pad kernel, and we'd otherwise pay it twice per expert. + padded_shards = [w.get_padded_shard() for w in weights] + + # Phase 1: per-weight local amax into each w.quantized's amax buffers. + # Keep rowwise/columnwise both populated so the group allreduce sees + # whichever the consumer GEMM will read. + for w in weights: + w._quantizer.set_usage(rowwise=True, columnwise=True) + if _MULTI_AMAX_TE_API_AVAILABLE: + # Tier-2: single multi-tensor launch writes both rowwise and columnwise + # amax directly (no per-expert D2D replicate), fusing N per-expert chains. + # Reuse the _cached_quantizers list already populated by _all_gather_weight + anchor = weights[0] + quantizer_list = anchor._cached_quantizers + if quantizer_list is None: + quantizer_list = [w._quantizer for w in weights] + anchor._cached_quantizers = quantizer_list + tex.compute_multi_amax_nvfp4( + padded_shards, + quantizer_list, + [w.quantized for w in weights], + ) + else: + for w, shard in zip(weights, padded_shards): + tex.compute_amax_nvfp4( + tensor=shard, + quantizer=w._quantizer, + output=w.quantized, + ) + + # Phase 2: one coalesced allreduce across every weight's amax tensors. + amax_tensors = [] + for w in weights: + rw = w.quantized._amax_rowwise + cw = w.quantized._amax_columnwise + if rw is not None: + amax_tensors.append(rw) + if cw is not None and (rw is None or cw.data_ptr() != rw.data_ptr()): + amax_tensors.append(cw) + torch.distributed.all_reduce_coalesced( + amax_tensors, + op=torch.distributed.ReduceOp.MAX, + group=group, + ) + + # Phase 3: per-weight cast using the pre-reduced amax; skips the internal + # allreduce inside the quantizer. + for w, shard in zip(weights, padded_shards): + tex.quantize_cast_only_nvfp4( + tensor=shard, + quantizer=w._quantizer, + output=w.quantized, + noop=cast_noop_flag, + ) + w.did_cast_to_low_precision = True + + +@dataclass +class ETPConfig: + """Global configuration for Extended Tensor Parallelism.""" + pad_for_alignment: int = 16 + check_param_states: bool = False + weight_prefetch: bool = True + # When True and the weight list in _all_gather_weight contains >1 NVFP4 + # shards that share an amax reduction group, coalesce their per-expert + # amax allreduces into a single NCCL call. Requires TE with + # tex.compute_amax_nvfp4 / tex.quantize_cast_only_nvfp4; the eligibility + # guard in _coalesced_amax_static_eligible falls back to the per-weight + # path when either binding is missing. + coalesce_amax_allreduce: bool = True + # Log numeric diagnostics for the first N AG/RS calls per param. + # 0 = off; 3 = good default for triage (covers iter 1-2 fwd+bwd). + debug_numerics: int = 0 + +ETP_CONFIG = ETPConfig() + +# --------------------------------------------------------------------------- +# Debug helpers (gated by ETP_CONFIG.debug_numerics > 0) +# --------------------------------------------------------------------------- +_etp_debug_counts: Dict[tuple, int] = {} + +def _etp_dbg_capturing(): + """True when a CUDA graph is being captured — D2H syncs are forbidden.""" + return torch.cuda.is_current_stream_capturing() + +def _etp_dbg_should_log(param_name, label): + if ETP_CONFIG.debug_numerics <= 0 or _etp_dbg_capturing(): + return False + key = (param_name, label) + count = _etp_debug_counts.get(key, 0) + if count >= ETP_CONFIG.debug_numerics: + return False + _etp_debug_counts[key] = count + 1 + return True + +def _etp_dbg_tensor(name, t): + """One-line NaN/Inf summary for a BF16/FP32 tensor.""" + if t is None: + return f"{name}=None" + if t.numel() == 0: + return f"{name}:{list(t.shape)},empty" + if not t.is_floating_point(): + return f"{name}:non-float({t.dtype})" + has_nan = bool(torch.isnan(t).any()) + has_inf = bool(torch.isinf(t).any()) + amax = t.abs().max().item() + tag = " ***BAD***" if (has_nan or has_inf) else "" + return f"{name}:{list(t.shape)},amax={amax:.4e},nan={has_nan},inf={has_inf}{tag}" + +def _etp_dbg_quantized(name, qt): + """Multi-line check of a quantized tensor's metadata fields.""" + if qt is None: + return f"{name}=None" + md = qt.get_metadata() + parts = [f"{name}:type={type(qt).__name__}"] + for k in ("rowwise_data", "columnwise_data"): + v = md.get(k) + parts.append(f" {k}={'shape=' + str(list(v.shape)) if v is not None else 'NONE'}") + for k in ("rowwise_scale_inv", "columnwise_scale_inv"): + v = md.get(k) + if v is not None and v.numel() == 0: + parts.append(f" {k}:{list(v.shape)},empty") + elif v is not None and v.is_floating_point(): + has_nan = bool(torch.isnan(v).any()) + has_inf = bool(torch.isinf(v).any()) + amax = v.abs().max().item() + tag = " ***BAD***" if (has_nan or has_inf) else "" + parts.append(f" {k}:{list(v.shape)},amax={amax:.4e},nan={has_nan},inf={has_inf}{tag}") + elif v is not None: + parts.append(f" {k}:{list(v.shape)},dtype={v.dtype}") + else: + parts.append(f" {k}:NONE") + return "\n".join(parts) + +def update_config(**kwargs): + """Update the global ETP configuration.""" + for key, value in kwargs.items(): + if not hasattr(ETP_CONFIG, key): + raise ValueError(f"Unknown ETP config option: {key}") + setattr(ETP_CONFIG, key, value) + + +def tag_etp_params_with_names(model): + """Populate _debug_name on every ETPShardedParam with its full dotted parameter name. + + Call once after model construction so the linking log prints human-readable names + instead of raw tensor ids. + """ + for name, param in model.named_parameters(): + if isinstance(param, ETPShardedParam): + param._debug_name = name + + +def wrap_module_params_etp(module, weight_names, etp_group, is_grouped=None): + """Shard and re-register all parameters of a module using ETP weight sharding.""" + if etp_group.size() == 1: + return + + etp_size = etp_group.size() + etp_rank = etp_group.rank() + + for idx, name in enumerate(weight_names): + param = getattr(module, name, None) + if param is None: + continue + + # delete the original parameter, which will be replaced by an ETP sharded one + delattr(module, name) + + if ETP_CONFIG.pad_for_alignment > 0: + # Pad the full tensor BEFORE sharding so every rank gets exactly + # shard_size rows and each shard's dim0 is alignment-divisible. + # Padding stays contiguous at the tail of the gathered result — + # no interleaved-padding reshuffle needed after all-gather. + alignment = ETP_CONFIG.pad_for_alignment * etp_size + tensor = param.data + dim0 = tensor.shape[0] + pad_length = (alignment - dim0 % alignment) % alignment if alignment > 0 else 0 + if pad_length > 0: + tensor = torch.nn.functional.pad(tensor, (0, 0, 0, pad_length)) + padded_dim0 = dim0 + pad_length + shard_size = padded_dim0 // etp_size + shard = tensor[etp_rank * shard_size : (etp_rank + 1) * shard_size] + etp_shard = ETPShardedParam(shard.clone()) + etp_shard.pad_length = pad_length + else: + shard_size = tensor.shape[0] // etp_group.size() + shard = tensor[etp_rank * shard_size: (etp_rank + 1) * shard_size] + etp_shard = ETPShardedParam(shard.clone()) + + if is_grouped: + etp_shard.expert_idx = idx + etp_shard.is_routed_expert = True + # Grouped routed experts are UNGRAPHED unless the "moe" scope captures + # them; classify_etp_chains() will fix this up at init time based on + # the actual cuda_graph_scope. We set UNGRAPHED here as a safe default. + etp_shard.chain_id = ETPChain.UNGRAPHED.value + etp_shard.group = etp_group + etp_shard.ps_size = etp_size + # register the newly sharded param back to the module + module._parameters[name] = etp_shard + + global _ETP_PARAMS + _ETP_PARAMS.append(etp_shard) + + if is_grouped: + allweights = [getattr(module, name) for name in weight_names] + allweights[0].weight_list = allweights + + +class ETPShardHandle: + + def __init__(self, handle, etp_shards, reduce_scatter=False): + self.handle = handle + self.etp_shards = etp_shards + self.reduce_scatter = reduce_scatter + _inflight_comm_params.add(etp_shards[0]) + + def wait(self): + if self.handle is not None: + self.handle.wait() + self.handle = None # Release NCCL Work and its C++ tensor references promptly + for w in self.etp_shards: + if self.reduce_scatter: + w._set_rs_state(ETPWeightState.DATA_READY) + else: + w._set_state(ETPWeightState.DATA_READY) + + _inflight_comm_params.discard(self.etp_shards[0]) + + +class ETPShardedParam(torch.nn.Parameter): + + _pending_rs_weight = None + _first_weight_flag = True + # Per-chain state: each chain_id (ETPChain.GRAPHED / ETPChain.UNGRAPHED) has + # its own linked list. Chains never cross-link: prev_w/next_w only connect + # params with the same chain_id. + _chain_state: Dict[str, dict] = {} + + @classmethod + def _get_chain_state(cls, chain_id: str) -> dict: + if chain_id not in cls._chain_state: + cls._chain_state[chain_id] = { + 'last_weight': None, + 'link_node_count': 0, + 'link_table_buffer': [], + 'link_table_flushed': False, + } + return cls._chain_state[chain_id] + + @classmethod + def _buffer_link_table_row(cls, prev: "ETPShardedParam", curr: "ETPShardedParam", chain: dict) -> None: + """Buffer one row of the prefetch-link table (flushed atomically on the second forward pass).""" + _W = 70 + + def _layer_id(name: str) -> str: + m = re.search(r"\d+", name) + return m.group() if m else "-" + + chain['link_node_count'] += 1 + if chain['link_node_count'] == 1: + chain_id = getattr(curr, 'chain_id', ETPChain.UNGRAPHED.value) + chain['link_table_buffer'].append( + f"\n[{chain_id} chain]" + f"\n{'node_id':>7} | {'layer_id':>8} | {'curr_weight_name':<{_W}} | prev_weight_name" + f"\n{'-'*7}-+-{'-'*8}-+-{'-'*_W}-+-{'-'*_W}" + ) + # Seed weight (first ETP param) as row 0 + chain['link_table_buffer'].append( + f"{'0':>7} | {_layer_id(prev._debug_name):>8} | {prev._debug_name:<{_W}} | -" + ) + chain['link_table_buffer'].append( + f"{chain['link_node_count']:>7} | {_layer_id(curr._debug_name):>8} | " + f"{curr._debug_name:<{_W}} | {prev._debug_name}" + ) + + @staticmethod + def __new__(cls, tensor, *args, **kwargs): + requires_grad = kwargs.get('requires_grad', True) + return super(ETPShardedParam, cls).__new__(cls, tensor, requires_grad=requires_grad) + + def __init__(self, x, *args, **kwargs): + super().__init__() + + # all gather + self.state = ETPWeightState.NONE + self._ag_ticket_fwd = None + self._ag_ticket_bwd = None + self._prefetch_handle = None + self._need_weight_prefetch = True + # Per-direction prefetch opt-outs. Default True. The embedding weight + # never needs an AG during bwd (its wgrad is a scatter-add indexed by + # token ids, and its input is non-differentiable, so no dgrad either). + # classify_etp_chains() sets this to False for embedding.word_embeddings.weight. + self._need_weight_prefetch_bwd = True + self.ag_event = torch.cuda.Event(external=True) + # DDP backward hook (set by register_grad_accum_hook) + self._grad_accum_node = None + self._grad_accum_hook = None + # Quantization + self._quantizer = None + self.did_cast_to_low_precision = False + self.quantized = None + # Prefetching linked list + self.prefetch_initialized = False + self.next_w = None + self.prev_w = None + # Chain identity (ETPChain.GRAPHED / ETPChain.UNGRAPHED). Defaults to + # UNGRAPHED as a safe fallback; classify_etp_chains(model) walks the + # model at init time (after set_cuda_graph_scope) and reclassifies + # based on param name + active cuda_graph_scope. + self.chain_id = ETPChain.UNGRAPHED.value + # Grouped gemm + self.is_routed_expert = False + self.expert_idx = None + self.group = None + self.weight_list = None + # Reduce-scatter state (set during wgrad_reduce_scatter) + self.rs_state = ETPWeightState.NONE + self.wgrad_rs = None + self._wgrad_rs_handle = None + self.rs_event = torch.cuda.Event(external=True) + self._rs_ticket = None + # Padding + self.pad_length = 0 + # Debug + self._debug_name = "" + # Hot-path caches (populated lazily on first use). chain_id/group are + # set after __init__, so we can't resolve streams eagerly here. + self._cached_ag_stream = None + self._cached_rs_stream = None + self._cached_quantizers = None + self._cached_dtypes = None + self._cached_etp_group = None + + def setup(self, weight_quantizer=None): + """Set quantizer and create quantized shard.""" + + if self._quantizer is None: + def _configure_quantizer(q, group): + q = q.copy() + if hasattr(q, 'with_amax_reduction'): + q.with_amax_reduction = True + q.amax_reduction_group = group + q.internal = False + # MXFP8 scales must stay in compact (unswizzled) layout so that + # per-shard scale_inv can be all-gathered via byte concatenation. + # GEMM-swizzled scales from independent shards don't compose into + # a valid swizzled layout for the full tensor after AG. + q.optimize_for_gemm = not isinstance(q, MXFP8Quantizer) + return q + + weights = self.weight_list if self.is_routed_expert and self.weight_list is not None else [self] + for quantizer, weight in zip(weight_quantizer, weights): + if quantizer is None: + continue + + weight._quantizer = _configure_quantizer(quantizer, weight.group) + weight.quantized = weight._quantizer.quantize(weight.get_padded_shard()) + weight.quantized.is_routed_expert = getattr(weight, 'is_routed_expert', False) + + @property + def _weights(self): + """Return the list of individual weight shards (self for non-routed, weight_list for routed).""" + weights = self.weight_list if self.is_routed_expert else [self] + # Safety: all weights must be in the same state. + assert all(w.state == weights[0].state for w in weights) + return list(weights) + + @property + def _unsharded_shape_padded(self): + out_shape = list(self.size()) + out_shape[0] = out_shape[0] * self.group.size() + return tuple(out_shape) + + @property + def _unsharded_shape(self): + out_shape = list(self._unsharded_shape_padded) + out_shape[0] -= self.pad_length + return tuple(out_shape) + + @property + def _sharded_padded_shape(self): + return tuple(self.size()) + + def get_padded_shard(self): + return self + + def _set_state(self, new_state: ETPWeightState): + # if ETP_CONFIG.check_param_states: + # assert new_state in _STATE_TRANSITIONS[self.state], \ + # f"Invalid state transition: {self.state} -> {new_state}" + self.state = new_state + + def _set_rs_state(self, new_state: ETPWeightState): + # if ETP_CONFIG.check_param_states: + # assert new_state in _STATE_TRANSITIONS[self.rs_state], \ + # f"Invalid state transition: {self.rs_state} -> {new_state}" + self.rs_state = new_state + + def _get_cache_key(self, dtype, fwd: bool, reduce_scatter: bool) -> tuple: + """Build cache key using output shape + dtype. + + Weights with matching gathered shape and dtype share a buffer. + For expert weights gathered in parallel, self.expert_idx distinguishes them so + each gets a distinct buffer, while same-indexed experts across layers share. + """ + + if not isinstance(dtype, torch.dtype): + return (self._unsharded_shape_padded, dtype, fwd, not fwd, self.expert_idx, reduce_scatter) + return (self._unsharded_shape_padded, dtype, self.expert_idx, reduce_scatter) + + def _quantize_if_needed(self, skip_weight_cast=False, cast_noop_flag=None): + """Re-quantize sharded weight into existing buffer. Returns quantized weight or self.""" + if self._quantizer is None: + self.did_cast_to_low_precision = False + return self + + self._quantizer.set_usage(rowwise=True, columnwise=True) + if skip_weight_cast is False or cast_noop_flag is not None: + tex.quantize( + tensor=self.get_padded_shard(), + quantizer=self._quantizer, + output=self.quantized, + noop=cast_noop_flag, + ) + self.did_cast_to_low_precision = True + + return self.quantized + + def _strip_padding(self, tensor): + if self.pad_length == 0: + return tensor + + if isinstance(tensor, QuantizedTensor): + assert isinstance(tensor, (NVFP4TensorStorage, MXFP8TensorStorage)), \ + f"Unsupported quantized tensor type for ETP padding: {type(tensor)}" + + metadata = tensor.get_metadata() + if metadata.get("rowwise_data") is not None: + metadata["rowwise_data"] = metadata["rowwise_data"][:-self.pad_length] + if metadata.get("columnwise_data") is not None: + if isinstance(tensor, NVFP4TensorStorage): + # NVFP4 transposes columnwise and packs 2 values per byte + metadata["columnwise_data"] = metadata["columnwise_data"][ + ..., :-self.pad_length // 2 + ].contiguous() + else: + # MXFP8 columnwise is not transposed, strip first dim + metadata["columnwise_data"] = metadata["columnwise_data"][ + :-self.pad_length + ] + M = self._unsharded_shape[0] + if isinstance(tensor, NVFP4TensorStorage): + # NVFP4 scale_inv shapes (see NVFP4Quantizer.get_scale_shape): + # rowwise_scale_inv: [round_up(M, 128), round_up(ceil(K/16), 4)] + # columnwise_scale_inv: [round_up(K, 128), round_up(ceil(M/16), 4)] + # ETP shards M (dim 0 of the weight), so strip to the unpadded sizes. + if metadata.get("rowwise_scale_inv") is not None: + m_rows = round_up_to_nearest_multiple(M, 128) + metadata["rowwise_scale_inv"] = metadata["rowwise_scale_inv"][:m_rows] + if metadata.get("columnwise_scale_inv") is not None: + m_tiles = round_up_to_nearest_multiple( + math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4 + ) + metadata["columnwise_scale_inv"] = ( + metadata["columnwise_scale_inv"][:, :m_tiles].contiguous() + ) + else: + # MXFP8 scale_inv shapes (see MXFP8Quantizer.get_scale_shape): + # rowwise_scale_inv: [round_up(M, 128), round_up(K//32, 4)] + # columnwise_scale_inv: [round_up(M//32, 4), round_up(K, 128)] + # ETP shards M (dim 0 of the weight), so strip to the unpadded sizes. + if metadata.get("rowwise_scale_inv") is not None: + m_rows = round_up_to_nearest_multiple(M, 128) + metadata["rowwise_scale_inv"] = metadata["rowwise_scale_inv"][:m_rows] + if metadata.get("columnwise_scale_inv") is not None: + m_tiles = round_up_to_nearest_multiple( + M // MXFP8_BLOCK_SCALING_SIZE, 4 + ) + metadata["columnwise_scale_inv"] = ( + metadata["columnwise_scale_inv"][:m_tiles] + ) + + return type(tensor)(**metadata, shape=self._unsharded_shape, dtype=torch.bfloat16) + else: + return tensor[:-self.pad_length] + + def _all_gather_weight(self, async_op, skip_weight_cast, cast_noop_flag, fwd, nvtx_label=None): + """Quantize (if needed) and all-gather weight. Returns (weight_total, handle).""" + if nvtx_label is None: + nvtx_label = ( + self._debug_name + + (".fwd" if fwd else ".bwd") + + (".async" if async_op else ".sync") + ) + + weights = self._weights + + # 1. Transition state for async gathers. + if async_op: + for w in weights: + w._set_state(ETPWeightState.ASYNC_WAIT) + else: + for w in weights: + w._set_state(ETPWeightState.DATA_READY_SYNC) + + # 2. Prepare: quantize, set usage direction. + # Static eligibility (quantizer class, flags, amax group) is fixed + # after model construction — compute once and cache on self so the + # hot path only pays the cheap per-call skip_weight_cast check. + if ETP_CONFIG.coalesce_amax_allreduce: + static_ok = getattr(self, "_coalesced_amax_static", None) + if static_ok is None: + static_ok = _coalesced_amax_static_eligible(weights) + self._coalesced_amax_static = static_ok + # Per-call: match the skip_weight_cast gate in _quantize_if_needed + # (fire when either skip_weight_cast is False or cast_noop_flag + # was provided by the FP8/NVFP4 recipe). + use_coalesced = static_ok and not ( + skip_weight_cast is True and cast_noop_flag is None + ) + else: + use_coalesced = False + + if _etp_dbg_should_log(self._debug_name, 'ag_decision'): + qtypes = [type(w._quantizer).__name__ for w in weights[:3]] + print_rank_0( + f"[ETP_DEBUG] AG {self._debug_name} fwd={fwd} chain={self.chain_id} " + f"coalesced={use_coalesced} skip_cast={skip_weight_cast} " + f"noop={'set' if cast_noop_flag is not None else 'None'} " + f"coalesce_cfg={ETP_CONFIG.coalesce_amax_allreduce} " + f"static_ok={getattr(self, '_coalesced_amax_static', 'N/A')} " + f"qtypes={qtypes}{'...' if len(weights)>3 else ''}" + ) + + if use_coalesced: + _quantize_with_coalesced_amax(weights, skip_weight_cast, cast_noop_flag) + else: + for w in weights: + w._quantize_if_needed(skip_weight_cast, cast_noop_flag) + for w in weights: + if w.did_cast_to_low_precision: + w._quantizer.set_usage(rowwise=fwd, columnwise=not fwd) + + if _etp_dbg_should_log(self._debug_name, f'ag_numerics_{"fwd" if fwd else "bwd"}'): + lines = [f"[ETP_DEBUG] post-quantize {self._debug_name} fwd={fwd} " + f"usage=row:{fwd},col:{not fwd}"] + for i, w in enumerate(weights[:3]): + lines.append(f" w[{i}] shard: {_etp_dbg_tensor(f'{w._debug_name}', w.data)}") + if w.did_cast_to_low_precision: + lines.append(_etp_dbg_quantized(f' w[{i}] quantized', w.quantized)) + print_rank_0("\n".join(lines)) + + # 3. Build gather inputs. + # quantizers / dtypes / etp_group are stable after model construction — + # cache on the anchor (self == weights[0]) to avoid rebuilding lists + # every call. w.quantized is NOT cached because it can rebind. + quantizers = self._cached_quantizers + if quantizers is None: + quantizers = [w._quantizer for w in weights] + self._cached_quantizers = quantizers + if weights[0].did_cast_to_low_precision: + gather_weights = [w.quantized for w in weights] + else: + gather_weights = list(w.get_padded_shard() for w in weights) + + # 4. Cache checkout — use pooled buffers for both async and sync gathers + # to avoid allocating fresh memory each iteration. + dtypes = self._cached_dtypes + if dtypes is None: + dtypes = [q.dtype if q is not None else w.dtype for q, w in zip(quantizers, weights)] + self._cached_dtypes = dtypes + out_buffers = [] + cache = get_global_ETP_cache() + for p, dt in zip(weights, dtypes): + if fwd: + if p._ag_ticket_fwd is None: + p._ag_ticket_fwd = cache.reserve(p, dt, fwd=True) + cache.get(p._ag_ticket_fwd) + cache.release(p._ag_ticket_fwd) + out_buffers.append(cache.get(p._ag_ticket_fwd)) + else: + if p._ag_ticket_bwd is None: + p._ag_ticket_bwd = cache.reserve(p, dt, fwd=False) + out_buffers.append(cache.get(p._ag_ticket_bwd)) + + # 5. Communicate. + etp_group = self._cached_etp_group + if etp_group is None: + etp_group = weights[0].group + self._cached_etp_group = etp_group + if ETP_CONFIG.check_param_states and len(gather_weights) > 1: + # Debug invariant: batched AG needs distinct output buffers per expert. + assert len(set(id(b) for b in out_buffers)) == len(out_buffers), \ + "Duplicate output buffers in batched all-gather — experts need distinct cache keys" + + # ASYNC AG: wrap issue on ag_stream so both issue (NCCL preEvent) and + # wait land on the same stream — ag_stream's tail then reflects the + # collective's full lifecycle, which is what external + # wait_stream(ag_stream) drains depend on. Explicit outer→ag_stream + # event preserves the quantize writer edge (a bare stream context + # would drop it). + # SYNC AG: stay on caller — output ready on return. + if async_op: + outer_stream = torch.cuda.current_stream() + ag_stream = get_ag_stream(self.chain_id, etp_group) + outer_sync_event = torch.cuda.Event() + outer_sync_event.record(outer_stream) + ag_stream.wait_event(outer_sync_event) + ag_ctx = torch.cuda.stream(ag_stream) + else: + ag_ctx = nullcontext() + + with ag_ctx: + if len(gather_weights) > 1: + nvtx_range_push(f"{nvtx_label}.batched_etp_ag") + results, handle = grouped_gather_along_first_dim( + gather_weights, etp_group, + async_op=async_op, + quantizers=quantizers, + output_tensors=out_buffers, + ) + nvtx_range_pop(f"{nvtx_label}.batched_etp_ag") + else: + nvtx_range_push(f"{nvtx_label}.etp_ag") + weight_total, handle = gather_along_first_dim( + gather_weights[0], etp_group, + quantizer=quantizers[0], + async_op=async_op, + output_tensor=out_buffers[0] if out_buffers is not None else None, + ) + nvtx_range_pop(f"{nvtx_label}.etp_ag") + results = [weight_total] + + result = results if self.is_routed_expert else results[0] + + # 6. Wrap handle. + if async_op: + handle = ETPShardHandle(handle, weights) + else: + handle = None + + return result, handle + + def _wait_param_gather(self): + # Enter ag_stream context so handle.wait() + ag_event.record() both + # land on ag_stream. That makes ag_event mark ag_stream's tail, which + # is what external drains (wait_stream(ag_stream) in finalize_model_grads + # and cuda_graphs._wait_side_streams) actually block on. + ag_stream = self._cached_ag_stream + if ag_stream is None: + ag_stream = get_ag_stream(self.chain_id, self.group) + self._cached_ag_stream = ag_stream + with torch.cuda.stream(ag_stream): + if self._prefetch_handle is not None: + self._prefetch_handle.wait() + self._prefetch_handle = None + self.ag_event.record() + + def _all_gather_weight_on_demand(self, fwd, skip_weight_cast=False, cast_noop_flag=None): + result, _ = self._all_gather_weight( + async_op=False, + skip_weight_cast=skip_weight_cast, + cast_noop_flag=cast_noop_flag, + fwd=fwd, + ) + result = result if self.is_routed_expert else [result] + result = [self._strip_padding(r) for r in result] + result = [r.detach().requires_grad_(w.requires_grad) for r, w in zip(result,self._weights)] + return result if self.is_routed_expert else result[0] + + def _get_prefetched_weight(self, fwd, skip_weight_cast=False, cast_noop_flag=None): + # Stale-read guard: state must reflect an AG issued for this cycle; + # otherwise cache.get() would return the prior iter's AG buffer. + if ETP_CONFIG.check_param_states: + for w in self._weights: + assert w.state in ( + ETPWeightState.ASYNC_WAIT, + ETPWeightState.DATA_READY, + ETPWeightState.DATA_READY_SYNC, + ), ( + f"[ETP] _get_prefetched_weight({'fwd' if fwd else 'bwd'}) on " + f"{self._debug_name} with state={w.state!r} — no AG issued; " + f"cache.get() would return stale data. Check the chain's " + f"_need_weight_prefetch flag and issuer's prefetch logic." + ) + _was_drained = getattr(self, '_already_ag_drained', False) + if _was_drained: + # Producer already drained via wait_async_comms; skip the captured + # cross-graph wait (CUDA no-op anyway). Correctness is provided by + # the eager main_stream sync chain in the surrounding training loop. + self._already_ag_drained = False + else: + # Intra-graph or eager consume: drain inline. + self._wait_param_gather() + self.ag_event.wait() + + # Retrieve prefetched results from cache + result = [] + cache = get_global_ETP_cache() + for w in self._weights: + ticket = w._ag_ticket_fwd if fwd else w._ag_ticket_bwd + result.append(cache.get(ticket)) + + result = [self._strip_padding(r) for r in result] + + if _etp_dbg_should_log(self._debug_name, f'prefetch_{"fwd" if fwd else "bwd"}'): + lines = [f"[ETP_DEBUG] prefetched {self._debug_name} fwd={fwd} " + f"already_drained={_was_drained}"] + for i, r in enumerate(result[:3]): + if isinstance(r, (NVFP4TensorStorage, MXFP8TensorStorage)): + lines.append(_etp_dbg_quantized(f' gathered[{i}]', r)) + else: + lines.append(f" gathered[{i}]: {_etp_dbg_tensor('', r)}") + print_rank_0("\n".join(lines)) + + result = [r.detach().requires_grad_(w.requires_grad) for r, w in zip(result, self._weights)] + return result if self.is_routed_expert else result[0] + + def all_gather_and_prefetch_bwd(self, nvtx_label=None): + """ + Backward variant: get current weight (from cache if prefetched, else + sync gather) and async-prefetch prev_w. + + Safe thanks to the coat-check cache: get() returns the current buffer + to the pool, and the prefetch's checkout() will allocate a separate + buffer if the pool is empty (i.e. the current buffer is still live + via the caller's tensor reference). + + Returns: + weight_total + """ + + if ETP_CONFIG.weight_prefetch and self.next_w is not None: + result = self._get_prefetched_weight(False, skip_weight_cast=True) + else: + result = self._all_gather_weight_on_demand(False, skip_weight_cast=True) + + if ( + ETP_CONFIG.weight_prefetch + and self.prev_w is not None + and self.prev_w._need_weight_prefetch + and self.prev_w._need_weight_prefetch_bwd + ): + # Pre-AG work (quantize, ticket lookup) runs on caller's stream; + # the NCCL collective itself is wrapped on ag_stream inside + # _all_gather_weight (see the async/sync gate there for rationale). + _, handle = self.prev_w._all_gather_weight( + async_op=True, skip_weight_cast=True, cast_noop_flag=None, + fwd=False, nvtx_label=nvtx_label, + ) + self.prev_w._prefetch_handle = handle + + # The unsharded tensor has been returned, no pending work so reset state to NONE + for w in self._weights: + w._set_state(ETPWeightState.NONE) + + if ETP_CONFIG.weight_prefetch and self.next_w is not None: + cache = get_global_ETP_cache() + for w in self._weights: + cache.release(w._ag_ticket_bwd) + + return result + + def batched_all_gather_and_prefetch_bwd(self, nvtx_label=None): + """Batched backward all-gather + prefetch. Wrapper around all_gather_and_prefetch_bwd.""" + assert self.is_routed_expert and self.weight_list is not None + return self.all_gather_and_prefetch_bwd(nvtx_label=nvtx_label) + + def all_gather_and_prefetch( + self, + fwd: bool = True, + skip_weight_cast: bool = False, + cast_noop_flag: torch.Tensor = None, + nvtx_label: str = None, + ): + """ + All-gather current weight and async-prefetch the next weight. + + Returns: + weight_total + """ + if ETP_CONFIG.weight_prefetch and self.prev_w is not None: + result = self._get_prefetched_weight(True, skip_weight_cast, cast_noop_flag) + else: + result = self._all_gather_weight_on_demand(True, skip_weight_cast, cast_noop_flag) + + # Prefetch next weight + if ( + ETP_CONFIG.weight_prefetch + and self.next_w is not None + and self.next_w._need_weight_prefetch + ): + # Pre-AG work on caller; NCCL wrap lives at the collective site + # inside _all_gather_weight. See all_gather_and_prefetch_bwd. + _, handle = self.next_w._all_gather_weight( + async_op=True, + skip_weight_cast=skip_weight_cast, + cast_noop_flag=cast_noop_flag, + fwd=fwd, nvtx_label=nvtx_label, + ) + self.next_w._prefetch_handle = handle + + # The unsharded tensor has been returned, no pending work so reset state to NONE + for w in self._weights: + w._set_state(ETPWeightState.NONE) + + # Lazy population of linked list: link previous weight to current weight + # Uses per-chain state so dense and expert chains never cross-link. + cls = type(self) + chain = cls._get_chain_state(self.chain_id) + if not self.prefetch_initialized: + last_w = chain['last_weight'] + if last_w is not None and last_w.next_w is None: + cls._buffer_link_table_row(last_w, self, chain) + last_w.next_w = self + self.prev_w = last_w + + cache = get_global_ETP_cache() + + # Set the fwd ag buffer + quantizers = [w._quantizer for w in self._weights] + dtypes = [q.dtype if q is not None else w.dtype for q, w in zip(quantizers, self._weights)] + for w, dt in zip(self._weights, dtypes): + w._ag_ticket_fwd = cache.reserve(w, dt, fwd=True) + cache.get(w._ag_ticket_fwd) + cache.release(w._ag_ticket_fwd) + + self.prefetch_initialized = True + chain['last_weight'] = self + elif not chain['link_table_flushed'] and chain['link_table_buffer']: + # Second forward pass: flush the complete table atomically to avoid interleaving + chain['link_table_flushed'] = True + print_rank_0("\n".join(chain['link_table_buffer']) + "\n") + + return result + + def batched_all_gather_and_prefetch(self, **kwargs): + """Batched all-gather + prefetch for expert weights. Wrapper around all_gather_and_prefetch.""" + assert self.is_routed_expert and self.weight_list is not None + return self.all_gather_and_prefetch(**kwargs) + + def get_wgrad_tensor(self): + return _wgrad_pool_get(self._unsharded_shape, self.main_grad.dtype, self.device) + + def register_grad_accum_hook(self, grad_accum_node, hook): + """Register a DDP backward hook to be called from _finalize_wgrad. + + For ETP params, autograd may receive None (async RS) so the normal grad + accumulator hook never fires. Instead, _finalize_wgrad calls the hook + explicitly after RS wait + gradient accumulation, ensuring DDP's + register_grad_ready fires at exactly the right time. + """ + self._grad_accum_node = grad_accum_node + self._grad_accum_hook = hook + + @staticmethod + def _handle_megatron_grad_accum(param): + """Handle megatron DDP and gradient accumulation fusion. + + Do NOT set param.grad before calling the hook — the hook checks + param.grad and would accumulate it into main_grad if zero_out_wgrad + is True, corrupting the gradient with a non-zero dummy. + """ + if hasattr(param, "grad_added_to_main_grad"): + param.grad_added_to_main_grad = True + dummy_grad = get_dummy_wgrad(list(param.main_grad.shape), param.dtype) + if getattr(param, '_grad_accum_hook', None) is not None: + param._grad_accum_hook() + + param._set_rs_state(ETPWeightState.NONE) + return dummy_grad + + + def _wait_reduce_scatter(self, finalize_grad=False): + # Enter rs_stream context so handle.wait() + rs_event.record() land + # on rs_stream — mirrors _wait_param_gather for the RS path. + # When finalize_grad=True, main_grad.add_ also runs on rs_stream + # (right after NCCL RS), so it starts during AG drain rather than + # after it — avoids SM-saturation blocking cross-graph overlap. + rs_stream = self._cached_rs_stream + if rs_stream is None: + rs_stream = get_rs_stream(self.chain_id, self.group) + self._cached_rs_stream = rs_stream + with torch.cuda.stream(rs_stream): + if self._wgrad_rs_handle is not None: + self._wgrad_rs_handle.wait() + self._wgrad_rs_handle = None + self.rs_event.record() + if finalize_grad: + cache = get_global_ETP_cache() + for w in self._weights: + w._set_rs_state(ETPWeightState.NONE) + wgrad_rs = cache.get(w._rs_ticket) + w.main_grad.add_(wgrad_rs) + if ETP_CONFIG.debug_numerics > 0 and not _etp_dbg_capturing(): + if bool(torch.isinf(w.main_grad).any()) or bool(torch.isnan(w.main_grad).any()): + print_rank_0( + f"[ETP_DEBUG] *** main_grad ANOMALY after finalize_grad RS *** " + f"{w._debug_name}: {_etp_dbg_tensor('main_grad', w.main_grad)}" + ) + cache.release(w._rs_ticket) + if hasattr(w, "grad_added_to_main_grad"): + w.grad_added_to_main_grad = True + self._already_finalized = True + # Release stashed wgrad inputs: UNGRAPHED buffers go back to the pool; + # GRAPHED just drops Python refs (addresses must stay stable for CG). + if getattr(self, '_wgrad_input_bufs', None) is not None: + if self.chain_id == ETPChain.UNGRAPHED.value: + for buf in self._wgrad_input_bufs: + _wgrad_pool_put(buf) + self._wgrad_input_bufs = None + + def _reduce_scatter(self, wgrads, async_op, nvtx_label=None): + """Reduce-scatter one or more wgrads. Returns (outputs, handle). + + Single tensor: plain reduce-scatter (no coalescing). + Multiple tensors: coalesced reduce-scatter. + """ + if nvtx_label is None: + nvtx_label = ( + self._debug_name + + ".bwd" + + (".async" if async_op else ".sync") + ) + + for w in self._weights: + if async_op: + w._set_rs_state(ETPWeightState.ASYNC_WAIT) + else: + w._set_rs_state(ETPWeightState.DATA_READY_SYNC) + + if self.pad_length > 0: + wgrads = [torch.nn.functional.pad(w, (0, 0, 0, self.pad_length)) for w in wgrads] + + if async_op: + dtypes = [w.dtype for w in wgrads] + out_buffers = [] + cache = get_global_ETP_cache() + for p, dt in zip(self._weights, dtypes): + if p._rs_ticket is None: + p._rs_ticket = cache.reserve(p, dt, fwd=False, reduce_scatter=True) + out_buffers.append(cache.get(p._rs_ticket)) + else: + out_buffers = [None] * len(wgrads) + + # ASYNC RS: wrap issue on rs_stream — issue and wait on the same stream + # means rs_stream's tail reflects the full NCCL lifecycle, what + # external wait_stream(rs_stream) drains depend on. Explicit outer→ + # rs_stream event preserves the wgrad-GEMM writer edge. Mirrors AG. + # SYNC RS: stay on caller — same constraint as sync AG. + if async_op: + outer_stream = torch.cuda.current_stream() + rs_stream = get_rs_stream(self.chain_id, self.group) + outer_sync_event = torch.cuda.Event() + outer_sync_event.record(outer_stream) + rs_stream.wait_event(outer_sync_event) + rs_ctx = torch.cuda.stream(rs_stream) + else: + rs_ctx = nullcontext() + + with rs_ctx: + if len(wgrads) == 1: + nvtx_range_push(f"{nvtx_label}.etp_rs") + out, handle = reduce_scatter_along_first_dim( + wgrads[0], self.group, async_op=async_op, output=out_buffers[0] + ) + nvtx_range_pop(f"{nvtx_label}.etp_rs") + return [out], handle + else: + outputs = [] + nvtx_range_push(f"{nvtx_label}.batched_etp_rs") + with torch.distributed._coalescing_manager( + group=self.group, + device=wgrads[0].device, + async_ops=async_op, + ) as cm: + for out_buffer, tensor in zip(out_buffers, wgrads): + out, _ = reduce_scatter_along_first_dim(tensor, self.group, output=out_buffer) + outputs.append(out) + nvtx_range_pop(f"{nvtx_label}.batched_etp_rs") + + return outputs, cm if async_op else None + + def wgrad_reduce_scatter(self, wgrad, nvtx_label=None): + """Reduce-scatter wgrad(s). Sync for last weight, async+deferred for others. + + Accepts a single tensor (non-routed) or list of tensors (routed experts). + + Returns: + Single tensor or list for sync (last weight) — backward should return this. + None or tuple of Nones for async — backward should return this. + """ + batched = isinstance(wgrad, (list, tuple)) + wgrads = list(wgrad) if batched else [wgrad] + weights = self._weights + + # UNGRAPHED-chain wgrads are recycled via the standalone pool (_wgrad_pool_put). + # GRAPHED-chain wgrads cannot pool-recycle because CUDA graphs require + # stable buffer addresses across replay. + poolable = self.chain_id == ETPChain.UNGRAPHED.value + + if _etp_dbg_should_log(self._debug_name, 'rs_input'): + lines = [f"[ETP_DEBUG] RS input {self._debug_name} " + f"async={self.prev_w is not None and ETP_CONFIG.weight_prefetch}"] + for i, g in enumerate(wgrads[:3]): + lines.append(f" wgrad[{i}]: {_etp_dbg_tensor('', g)}") + print_rank_0("\n".join(lines)) + + if ETP_CONFIG.weight_prefetch and self.prev_w is not None: + # Async reduce-scatter (not last weight — deferred finish). Pre-RS + # work on caller; NCCL wrap lives at the collective site inside + # _reduce_scatter (mirrors the AG prefetch sites). + _, rs_handle = self._reduce_scatter(wgrads, async_op=True, nvtx_label=nvtx_label) + self._wgrad_rs_handle = ETPShardHandle(rs_handle, weights, reduce_scatter=True) + # Stash wgrad input buffers — cannot recycle yet because the async RS + # kernel is still reading them on rs_stream. + self._wgrad_input_bufs = wgrads + ret = tuple([None] * len(wgrads)) if batched else None + else: + # Sync reduce-scatter (last weight in chain) — RS done, recycle immediately + wgrads, _ = self._reduce_scatter(wgrads, async_op=False, nvtx_label=nvtx_label) + torch._foreach_add_([p.main_grad for p in weights], wgrads) + if ETP_CONFIG.debug_numerics > 0 and not _etp_dbg_capturing(): + for w in weights[:3]: + if bool(torch.isinf(w.main_grad).any()) or bool(torch.isnan(w.main_grad).any()): + print_rank_0( + f"[ETP_DEBUG] *** main_grad ANOMALY after sync RS *** " + f"{w._debug_name}: {_etp_dbg_tensor('main_grad', w.main_grad)}" + ) + result = [self._handle_megatron_grad_accum(p) for p in weights] + + if poolable: + for buf in wgrads: + _wgrad_pool_put(buf) + ret = result if batched else result[0] + + # Wait for last reduce scatter if it was async + # Currently only support reduce scattering in reverse order + if ETP_CONFIG.weight_prefetch and self.next_w is not None: + self.next_w._wait_reduce_scatter() + + if getattr(self.next_w, '_already_finalized', False): + self.next_w._already_finalized = False + else: + self.next_w.rs_event.wait() + cache = get_global_ETP_cache() + next_weights = self.next_w._weights + wgrads = [cache.get(w._rs_ticket) for w in next_weights] + torch._foreach_add_([w.main_grad for w in next_weights], wgrads) + if ETP_CONFIG.debug_numerics > 0 and not _etp_dbg_capturing(): + for w in next_weights[:3]: + if bool(torch.isinf(w.main_grad).any()) or bool(torch.isnan(w.main_grad).any()): + print_rank_0( + f"[ETP_DEBUG] *** main_grad ANOMALY after async RS finalize *** " + f"{w._debug_name}: {_etp_dbg_tensor('main_grad', w.main_grad)}" + ) + for w in next_weights: + self._handle_megatron_grad_accum(w) + cache.release(w._rs_ticket) + + return ret + + def batched_wgrad_reduce_scatter(self, wgrad_list, nvtx_label=None): + """Batched version of wgrad_reduce_scatter.""" + assert self.is_routed_expert and self.weight_list is not None + return self.wgrad_reduce_scatter(wgrad_list, nvtx_label=nvtx_label) + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func is torch.Tensor.detach: + with torch._C.DisableTorchFunctionSubclass(): + # Perform the raw detach + result = func(*args, **kwargs) + # Re-wrap it in your subclass so PyTorch is happy + return result.as_subclass(type(self)) + + # 2. For everything else (add, mul, etc.), be transparent/decay. + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + +def print_rank_0(message, rank=None): + """If distributed is initialized or rank is specified, print only on rank 0.""" + if rank is not None: + if rank == 0: + print(message, flush=True) + elif torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print(message, flush=True) + else: + print(message, flush=True) + +@dataclass +class _TicketSlot: + """Internal slot backing a persistent ticket in the ETP buffer cache.""" + key: tuple # cache key (shape, dtype, ...) + param: 'ETPShardedParam' # for lazy allocation metadata + dtype: object # torch.dtype or tex.DType + reduce_scatter: bool + fwd: bool + chain_id: str = ETPChain.GRAPHED.value # chain this slot belongs to + buf: Optional[torch.Tensor] = field(default=None) # None when released or after clear() + + +class ETPWeightCache: + """ + Ticket-based buffer pool for ETP all-gather / reduce-scatter buffers. + + - ``reserve(param, dtype, fwd)`` → ``ticket`` + Assigns a persistent ticket (no buffer allocated yet). + - ``get(ticket)`` → ``buffer`` + Returns the buffer, lazily allocating from pool or fresh if needed. + - ``release(ticket)`` + Returns the buffer to the pool. Ticket remains valid; next ``get()`` + will re-allocate from the pool. + - ``clear()`` + Drops all buffers and pools. Tickets remain valid; next ``get()`` + lazily allocates fresh buffers. + """ + + # Bytes per element for known dtypes (used for logging). + _BYTES_PER_ELEMENT = { + torch.bfloat16: 2, + torch.float16: 2, + torch.float32: 4, + tex.DType.kFloat4E2M1: 0.5, + tex.DType.kFloat8E4M3: 1, + } + + def __init__(self): + self._pool: Dict[tuple, List[torch.Tensor]] = defaultdict(list) + self._slots: Dict[int, _TicketSlot] = {} + self._next_ticket: int = 0 + self._total_bytes: int = 0 # running total of allocated bytes + self.key_to_allocate_func = {} + + @staticmethod + def _buf_bytes(shape, dtype) -> int: + """Estimate buffer size in bytes.""" + numel = 1 + for d in shape: + numel *= d + bpe = ETPWeightCache._BYTES_PER_ELEMENT.get(dtype, None) + return numel * bpe + + def _allocate_buffer(self, param: 'ETPShardedParam', dtype, reduce_scatter, fwd) -> torch.Tensor: + if reduce_scatter: + out_shape = param._sharded_padded_shape + else: + out_shape = param._unsharded_shape_padded + + if not isinstance(dtype, torch.dtype): + quantizer = param._quantizer + assert quantizer is not None + param._quantizer.set_usage(rowwise=fwd, columnwise=not fwd) + + buf = param._quantizer.make_empty( + out_shape, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + ) + else: + buf = torch.empty( + out_shape, dtype=dtype, device=param.device, memory_format=torch.contiguous_format + ) + + buf_bytes = self._buf_bytes(out_shape, dtype) + self._total_bytes += buf_bytes + print_rank_0( + f"[ETP Cache] +{buf_bytes / 1024**2:.1f} MB (shape={out_shape}, dtype={dtype}) " + f"total={self._total_bytes / 1024**2:.1f} MB id: {id(buf)} fwd: {fwd}" + ) + return buf + + def reserve(self, param: 'ETPShardedParam', dtype, fwd: bool, reduce_scatter=False) -> int: + """Assign a persistent ticket. No buffer is allocated until ``get()``.""" + key = param._get_cache_key(dtype, fwd, reduce_scatter) + ticket = self._next_ticket + self._next_ticket += 1 + + self._slots[ticket] = _TicketSlot( + key=key, param=param, dtype=dtype, reduce_scatter=reduce_scatter, fwd=fwd, + chain_id=getattr(param, 'chain_id', ETPChain.UNGRAPHED.value), + ) + return ticket + + def get(self, ticket: int) -> torch.Tensor: + """Return the buffer for *ticket*, lazily allocating if needed.""" + slot = self._slots[ticket] + if slot.buf is None: + pool = self._pool[slot.key] + slot.buf = pool.pop() if pool else self._allocate_buffer( + slot.param, slot.dtype, slot.reduce_scatter, fwd=slot.fwd + ) + self.key_to_allocate_func[slot.key] = (slot.param, slot.dtype, slot.reduce_scatter, slot.fwd) + + return slot.buf + + def release(self, ticket: int): + """Return the buffer to the pool. Ticket remains valid. + + slot.buf is intentionally NOT cleared: get() must stay idempotent so that + CUDA-graph-captured buffers keep their fixed address across replays, and + reallocate_to_mempool() can find every dense-chain buffer. + """ + slot = self._slots[ticket] + if slot.buf is None: + return + # Use identity check — tensor == tensor returns a multi-element bool tensor + # which crashes in a boolean context ("Boolean value of Tensor is ambiguous"). + if not any(b is slot.buf for b in self._pool.get(slot.key, [])): + self._pool[slot.key].append(slot.buf) + + def clear(self): + """Drop all buffers; tickets remain valid and lazily re-allocate on next get().""" + for slot in self._slots.values(): + slot.buf = None + self._pool.clear() + self._total_bytes = 0 + + def reallocate_to_mempool(self, device, mempool): + """Re-allocate GRAPHED-chain ticket buffers into a CUDA graph memory pool. + + Call BEFORE graph capture so every GRAPHED-chain buffer lives in the capture + pool and no allocations are recorded inside the graph. UNGRAPHED-chain + buffers are left in regular memory (they are never referenced by any + captured graph). + """ + + # Identify keys that belong to the GRAPHED chain + graphed_keys = set() + for slot in self._slots.values(): + if slot.chain_id == ETPChain.GRAPHED.value: + graphed_keys.add(slot.key) + + # Clone only GRAPHED-chain pool buffers into the passed in mempool + self._total_bytes = 0 + new_pool = defaultdict(list) + torch._C._cuda_beginAllocateCurrentThreadToPool(device, mempool) + for key, buffers in self._pool.items(): + if key not in graphed_keys: + continue + new_buffers = [] + for _ in range(len(buffers)): + buf = self._allocate_buffer(*self.key_to_allocate_func[key]) + new_buffers.append(buf) + new_pool[key] = new_buffers + torch._C._cuda_endAllocateToPool(device, mempool) + + # Map each buffer in the old pool to its corresponding new one (GRAPHED only) + old_to_new_buff = {} + for key, old_pool in self._pool.items(): + if key not in graphed_keys: + continue + new = new_pool[key] + for old_buf, new_buf in zip(old_pool, new): + old_to_new_buff[old_buf] = new_buf + + # Replace each GRAPHED slot's reference; keep UNGRAPHED slots unchanged + for slot in self._slots.values(): + if slot.chain_id == ETPChain.GRAPHED.value and slot.buf is not None and slot.buf in old_to_new_buff: + slot.buf = old_to_new_buff[slot.buf] + + # Merge: GRAPHED keys get new buffers, UNGRAPHED keys keep old ones + for key, buffers in self._pool.items(): + if key not in graphed_keys: + new_pool[key] = buffers + self._pool = new_pool + + # Remap quantized params into the CG mempool — but only for params on + # the GRAPHED chain. UNGRAPHED-chain params (embedding, output_layer, + # and MoE paths whose scope is not captured) run eagerly and don't + # need their quantized storage in the CG mempool. + torch._C._cuda_beginAllocateCurrentThreadToPool(device, mempool) + for w in _ETP_PARAMS: + if getattr(w, "chain_id", ETPChain.GRAPHED.value) != ETPChain.GRAPHED.value: + continue + if w.quantized is None: + continue + if isinstance(w.quantized, NVFP4TensorStorage): + w.quantized._rowwise_data = torch.clone(w.quantized._rowwise_data) + w.quantized._columnwise_data = torch.clone(w.quantized._columnwise_data) + w.quantized._rowwise_scale_inv = torch.clone(w.quantized._rowwise_scale_inv) + w.quantized._columnwise_scale_inv = torch.clone(w.quantized._columnwise_scale_inv) + w.quantized._amax_columnwise = torch.clone(w.quantized._amax_columnwise) + w.quantized._amax_rowwise = torch.clone(w.quantized._amax_rowwise) + elif isinstance(w.quantized, MXFP8TensorStorage): + w.quantized._rowwise_data = torch.clone(w.quantized._rowwise_data) + w.quantized._columnwise_data = torch.clone(w.quantized._columnwise_data) + w.quantized._rowwise_scale_inv = torch.clone(w.quantized._rowwise_scale_inv) + w.quantized._columnwise_scale_inv = torch.clone(w.quantized._columnwise_scale_inv) + else: + assert False + torch._C._cuda_endAllocateToPool(device, mempool) + + return + +def get_global_ETP_cache() -> ETPWeightCache: + """Get or lazily create the global cache instance.""" + global _ETP_CACHE + if _ETP_CACHE is None: + _ETP_CACHE = ETPWeightCache() + return _ETP_CACHE + + +def reallocate_etp_cache_to_mempool(device, mempool): + """Re-allocate all ETP cache buffers into a CUDA graph memory pool.""" + if _ETP_CACHE is not None: + _ETP_CACHE.reallocate_to_mempool(device, mempool) + + +def wait_async_comms(chain_id: str = None, skip_rs: bool = False, finalize_after_drain: bool = False): + """Drain in-flight ETP async AG / RS handles. + + When called inside CUDA graph capture, the drains are captured into that + graph. This is the producer-side hook for cross-graph AG/RS overlap: + captured cudaStreamWaitEvent on an event recorded in a different capture + session is a CUDA no-op, so consumer graphs can't safely wait on + cross-graph events. Instead, the producer drains here and flags the + param; the consumer reads the flag and skips its captured wait. + + Args: + chain_id: If specified, only drain params on this chain. + skip_rs: Drain AG only; leave RS in flight. + finalize_after_drain: After RS drain, also accumulate wgrad into + main_grad. Runs main_grad.add_ on rs_stream (right after + NCCL RS) so it starts during AG drain rather than after, + avoiding SM-saturation that blocks cross-graph overlap. + Falls back to caller-stream _finalize_wgrad if no RS handle. + + Per-param side effects: + * _already_ag_drained = True (if an AG handle was drained) + * _already_finalized = True (if finalize_after_drain=True) + """ + for param in list(_inflight_comm_params): + if chain_id is not None and getattr(param, 'chain_id', ETPChain.UNGRAPHED.value) != chain_id: + continue + had_ag = param._prefetch_handle is not None + param._wait_param_gather() + if had_ag: + param._already_ag_drained = True + if not skip_rs: + param._wait_reduce_scatter(finalize_grad=finalize_after_drain) + if finalize_after_drain and not getattr(param, '_already_finalized', False): + cache = get_global_ETP_cache() + param.rs_event.wait() + for w in param._weights: + ETPShardedParam._finalize_wgrad(w, cache.get(w._rs_ticket)) + cache.release(w._rs_ticket) + param._already_finalized = True + + +@dataclass +class BatchedNVFP4AllGatherAsyncHandle: + """Handle for batched asynchronous NVFP4 all-gathers.""" + output_handles: List[_NVFP4AllGatherAsyncHandle] + outer_async_handle: torch.distributed.Work + _synchronized: bool = False + + def wait(self) -> None: + """Wait for the async operation to complete and post-process the tensor.""" + if self._synchronized: + return + self.outer_async_handle.wait() + # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. + for output_handle in self.output_handles: + if output_handle is not None: + assert output_handle.async_handle is None + output_handle.post_process_nvfp4_gather() + # release any tensor references just in case + output_handle.output = None + output_handle.columnwise_data_interleaved = None + output_handle.columnwise_scale_inv_interleaved = None + + self._synchronized = True + + +def grouped_gather_along_first_dim( + weights: list, + process_group, + async_op: bool = False, + quantizers: list = None, + output_tensors: list = None, +): + """ + All-gather multiple weights in a single coalesced operation. + + Handles NVFP4 post-processing for both sync and async paths. + """ + # Determine device from first weight. + inp = weights[0] + if isinstance(inp, NVFP4TensorStorage): + device = ( + inp._rowwise_data.device if inp._rowwise_data is not None + else inp._columnwise_data.device + ) + else: + device = inp.device + + weights_all = [] + weight_handles = [] + with torch.distributed._coalescing_manager( + group=process_group, device=device, async_ops=async_op, + ) as gather_coalescing_manager: + for i, weight in enumerate(weights): + weight_all, weight_handle = gather_along_first_dim( + weight, process_group, + quantizer=quantizers[i], + output_tensor=output_tensors[i] if output_tensors is not None else None, + grouped=True, + ) + weights_all.append(weight_all) + weight_handles.append(weight_handle) + + if async_op: + handle = gather_coalescing_manager + has_nvfp4_handles = any( + isinstance(wh, _NVFP4AllGatherAsyncHandle) for wh in weight_handles + ) + if has_nvfp4_handles: + handle = BatchedNVFP4AllGatherAsyncHandle(weight_handles, handle) + else: + for wh in weight_handles: + if isinstance(wh, _NVFP4AllGatherAsyncHandle): + wh.post_process_nvfp4_gather() + handle = None + + return weights_all, handle diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f3e7b57cf1..fca5b4ee61 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -5,6 +5,7 @@ """GroupedLinear API""" from typing import Union, Optional, Callable, Tuple, List from itertools import chain +import traceback import warnings import functools @@ -22,6 +23,7 @@ _2X_ACC_WGRAD, ) from ._common import WeightGradStore +from .extended_tensor_parallelism import wrap_module_params_etp from ..quantization import FP8GlobalStateManager from ..utils import ( divide, @@ -96,6 +98,7 @@ def forward( skip_fp8_weight_update, save_original_input, debug, + etp_size, ) = non_tensor_args num_gemms = len(m_splits) @@ -104,6 +107,14 @@ def forward( device = inp.device weight_requires_grad = weights[0].requires_grad + if etp_size > 1: + weights_etp_sharded = weights + weights = weights[0].batched_all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=skip_fp8_weight_update, + ) + # Configure quantizers if save_original_input and isinstance(input_quantizers[0], Float8Quantizer): raise ValueError("DelayedScaling recipe is not supported with save_original_input") @@ -257,12 +268,20 @@ def forward( for weight in weights: ctx.weight_objects.append(weight) - tensors_to_save, tensor_objects = prepare_for_saving( - *inputmats, - *weights_fp8, - *weights, - *biases, - ) + if etp_size == 1: + tensors_to_save, tensor_objects = prepare_for_saving( + *inputmats, + *weights_fp8, + *weights, + *biases, + ) + else: + tensors_to_save, tensor_objects = prepare_for_saving( + *inputmats, + *weights_etp_sharded, + *biases, + ) + ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects @@ -278,6 +297,8 @@ def forward( if hasattr(weights[0], "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)] + elif etp_size > 1: + ctx.main_grad_funcs = [weights_etp_sharded[i].get_wgrad_tensor for i in range(num_gemms)] else: ctx.main_grad_funcs = [ lambda j=i: weights[j].main_grad for i in range(num_gemms) @@ -308,6 +329,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + ctx.etp_size = etp_size # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -318,10 +340,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], with get_nvtx_range_context("_GroupedLinear_backward"): saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) N = ctx.num_gemms - inputmats = saved_tensors[:N] - weights = saved_tensors[N : 2 * N] - origin_weights = saved_tensors[2 * N : 3 * N] - biases = saved_tensors[3 * N : 4 * N] + if ctx.etp_size == 1: + inputmats = saved_tensors[:N] + weights = saved_tensors[N : 2 * N] + origin_weights = saved_tensors[2 * N : 3 * N] + biases = saved_tensors[3 * N : 4 * N] + else: + inputmats = saved_tensors[:N] + origin_weights = saved_tensors[N : 2 * N] + biases = saved_tensors[2 * N : 3 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] if ctx.cpu_offloading: @@ -330,7 +357,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], origin_weights[i] = ctx.weight_objects[i] ctx.weight_objects[i] = None - if ctx.fuse_wgrad_accumulation: + if ctx.fuse_wgrad_accumulation and ctx.etp_size == 1: for i in range(N): origin_weights[i].main_grad = main_grads[i] @@ -383,13 +410,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.m_splits, ) - if ctx.is_first_microbatch is not None: + if ctx.etp_size > 1: + accumulate_wgrad_into_param_main_grad = False + elif ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + if ctx.etp_size > 1: + weights = origin_weights[0].batched_all_gather_and_prefetch_bwd() + if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD if ctx.fp8 or ctx.debug: @@ -421,6 +453,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator=dgrad_gemm_use_split_accumulator, ) + # Gathered weights are no longer needed after dgrad GEMM. + # For nvfp4, the NVFP4TensorStorage and its sub-tensors (scale_inv etc.) + # would otherwise survive until function return via this local ref. + if ctx.etp_size > 1: + weight_sizes = [w.size() for w in weights] + del weights + if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD if ctx.fp8: @@ -432,9 +471,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.fuse_wgrad_accumulation: wgrad_list = main_grads else: + sizes = weight_sizes if ctx.etp_size > 1 else [w.size() for w in weights] wgrad_list = [ - torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) - for w in weights + torch.empty(sz, dtype=ctx.activation_dtype, device=ctx.device) + for sz in sizes ] if ctx.save_original_input: @@ -476,7 +516,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=( accumulate_wgrad_into_param_main_grad - if not getattr(weights[0], "overwrite_main_grad", False) + if ctx.etp_size == 1 and not getattr(weights[0], "overwrite_main_grad", False) else False ), ) @@ -518,10 +558,19 @@ def handle_custom_ddp_from_mcore(weight, wgrad): wgrad = None return wgrad - wgrad_list = [ - handle_custom_ddp_from_mcore(weight, wgrad) - for weight, wgrad in zip(origin_weights, wgrad_list) - ] + if ctx.etp_size > 1: + wgrad_list = origin_weights[0].batched_wgrad_reduce_scatter(wgrad_list) + # Drop Python refs to wgrad input buffers. The async RS on rs_stream + # still holds C++ refs (via NCCL Work); those are released when + # _wait_reduce_scatter calls handle.wait() + self.handle = None. + # Without this del, main_grads keeps the tensors alive until function + # return, wasting memory during graph capture warmup. + del main_grads + elif ctx.fuse_wgrad_accumulation: + wgrad_list = [ + handle_custom_ddp_from_mcore(weight, wgrad) + for weight, wgrad in zip(origin_weights, wgrad_list) + ] else: wgrad_list = [None] * ctx.num_gemms @@ -630,6 +679,7 @@ def __init__( save_original_input: bool = False, single_grouped_parameter: bool = False, name: Optional[str] = None, + etp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -682,6 +732,12 @@ def __init__( "Because the TP communication is handled outside of this module." ) + if etp_group is None: + self.etp_size = 1 + else: + self.etp_size = get_distributed_world_size(etp_group) + assert tp_size == 1, f"TODO(shiqingf): ETP+TP is not well supported yet." + self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes @@ -734,6 +790,10 @@ def __init__( is_meta = torch.device(device).type == "meta" self.reset_parameters(defer_init=is_meta) + if etp_group is not None: + weight_names = [f"weight{idx}" for idx in range(self.num_gemms)] + wrap_module_params_etp(self, weight_names, etp_group, is_grouped=True) + if self.wgrad_store.delay_wgrad_compute(): for name, param in self.named_parameters(): for i in range(self.num_gemms): @@ -887,6 +947,11 @@ def forward( weight_tensors = self._get_weight_tensors() bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + if self.etp_size > 1: + weight_tensors[0].setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() if debug: @@ -932,6 +997,7 @@ def forward( None, # skip_fp8_weight_update self.save_original_input, debug, + self.etp_size, ) out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ce0581024a..6f9ac58460 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -26,6 +26,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from .extended_tensor_parallelism import wrap_module_params_etp from ..quantization import FP8GlobalStateManager from ..utils import ( assert_dim_for_fp8_exec, @@ -140,6 +141,7 @@ def forward( skip_fp8_weight_update, symmetric_ar_type, debug, + etp_size, ) = non_tensor_args # NVTX label for profiling @@ -286,6 +288,15 @@ def forward( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + + if etp_size > 1: + weight_etp_sharded = weight + weight = weight.all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=skip_fp8_weight_update, + ) + weightmat = weight is_weight_param_quantized = False if fp8 or debug: @@ -400,7 +411,7 @@ def forward( nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") else: out = gemm_out - out = out.view(-1, *inp_shape[1:-1], out_features) + out = out.view(-1, *inp_shape[1:-1], out.shape[-1]) # ------------------------------------------------------ # Output tensor is ready to return... # ------------------------------------------------------ @@ -463,8 +474,9 @@ def forward( tensors_to_save, tensor_objects = prepare_for_saving( inputmat, - weightmat, - weight, + # For ETP, avoid keeping the gathered weightmat in memory for memory saving. + weightmat if etp_size == 1 else None, + weight if etp_size == 1 else weight_etp_sharded, bias, ln_weight, ln_out, @@ -483,6 +495,8 @@ def forward( if hasattr(weight, "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward ctx.main_grad_func = weight.get_main_grad + elif etp_size > 1: + ctx.main_grad_func = weight_etp_sharded.get_wgrad_tensor else: ctx.main_grad_func = lambda: weight.main_grad ctx.grad_input_quantizer = grad_input_quantizer @@ -523,6 +537,7 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module ctx.wgrad_store = wgrad_store ctx.debug = debug + ctx.etp_size = etp_size # ------------------------------------------------------ # Cached state for backward pass is ready... @@ -560,6 +575,9 @@ def backward( rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + if ctx.etp_size > 1: + weight = origin_weight.all_gather_and_prefetch_bwd() + # Delete the references to tensor objects once they've been consumed # by the `restore_from_saved` method to construct back the actual tensors. ctx.tensor_objects = None @@ -590,7 +608,7 @@ def backward( if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: origin_weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation and ctx.etp_size == 1: origin_weight.main_grad = main_grad # Configure Userbuffers communication (comm+GEMM overlap) @@ -701,7 +719,7 @@ def backward( # Note: Gradient w.r.t. GEMM input (i.e. norm output). # -------------------------------------------------- - # Make sure required data is available + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): @@ -843,7 +861,11 @@ def backward( use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad - if ctx.is_first_microbatch is not None: + if ctx.etp_size > 1: + # When ETP is enabled, GA is always disabled. ETP Wgrad workflow: + # allocte wgrad_out tmp buffer -> RS(wgrad_gemm) -> GradientAccumulation + accumulate_wgrad_into_param_main_grad = False + elif ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) @@ -915,6 +937,9 @@ def wgrad_gemm( # Call wgrad GEMM now wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output) + if ctx.etp_size > 1: + wgrad = origin_weight.wgrad_reduce_scatter(wgrad) + # Update grad bias if needed if grad_bias is None: grad_bias = grad_bias_ @@ -940,7 +965,6 @@ def wgrad_gemm( dgrad = reduce_scatter_out else: dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone() - # -------------------------------------------------- # Grad weight has been computed... # -------------------------------------------------- @@ -994,7 +1018,9 @@ def wgrad_gemm( if ctx.requires_wgrad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): + if ctx.etp_size > 1: + pass + elif ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): origin_weight.grad_added_to_main_grad = True if getattr(origin_weight, "zero_out_wgrad", False): wgrad = get_dummy_wgrad( @@ -1160,6 +1186,7 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, name: Optional[str] = None, + etp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -1190,6 +1217,10 @@ def __init__( self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() + if etp_group is None: + self.etp_size = 1 + else: + self.etp_size = get_distributed_world_size(etp_group) self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes @@ -1382,6 +1413,10 @@ def __init__( self.reset_parameters(defer_init=device == "meta") + if etp_group is not None: + wrap_module_params_etp(self, self.weight_names, etp_group) + del weight_tensor + # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.parallel_mode == "row" and self.apply_bias: @@ -1516,6 +1551,11 @@ def forward( # Get concatenated weight and bias tensors weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + if self.etp_size > 1: + weight_tensor.setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = ( self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) if not debug @@ -1580,6 +1620,7 @@ def forward( skip_fp8_weight_update, self.symmetric_ar_type, debug, + self.etp_size, ) out = fwd_fn( *autograd_ctx, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 31dac4d329..145c253cdc 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -25,6 +25,7 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore +from .extended_tensor_parallelism import wrap_module_params_etp from ..quantization import FP8GlobalStateManager from ..utils import ( cast_if_needed, @@ -128,6 +129,7 @@ def forward( symmetric_ar_type, save_original_input, debug, + etp_size, ) = non_tensor_args # NVTX label for profiling @@ -249,6 +251,15 @@ def forward( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + + if etp_size > 1: + weight_etp_sharded = weight + weight = weight.all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=skip_fp8_weight_update, + ) + weightmat = weight if fp8 or debug: # Configure quantizer @@ -434,8 +445,8 @@ def forward( # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, - weightmat, - weight, + weightmat if etp_size == 1 else None, + weight if etp_size == 1 else weight_etp_sharded, bias, ) ctx.save_for_backward(*tensors_to_save) @@ -456,6 +467,8 @@ def forward( if hasattr(weight, "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward ctx.main_grad_func = weight.get_main_grad + elif etp_size > 1: + ctx.main_grad_func = weight_etp_sharded.get_wgrad_tensor else: ctx.main_grad_func = lambda: weight.main_grad @@ -486,6 +499,7 @@ def forward( if in_fp8_activation_recompute_phase(): FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module ctx.wgrad_store = wgrad_store + ctx.etp_size = etp_size # ------------------------------------------------------ # Cached state for backward pass is ready... @@ -522,7 +536,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation and ctx.etp_size == 1: weight.main_grad = main_grad # Gather intermediate/activation tensors if needed @@ -684,6 +698,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Compute grad input tensor # -------------------------------------------------- + if ctx.etp_size > 1: + weight_fp8 = weight.all_gather_and_prefetch_bwd() + dgrad = None dgrad_work = None if ctx.requires_dgrad: @@ -832,7 +849,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad - if ctx.is_first_microbatch is not None: + if ctx.etp_size > 1: + accumulate_wgrad_into_param_main_grad = False + elif ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) @@ -943,6 +962,8 @@ def wgrad_gemm( dgrad_work.wait() dgrad_work = None + if ctx.etp_size > 1: + wgrad = weight.wgrad_reduce_scatter(wgrad) if ctx.requires_wgrad: # Handle custom DDP from mcore. if ( @@ -1098,6 +1119,7 @@ def __init__( symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, name: Optional[str] = None, + etp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -1126,6 +1148,11 @@ def __init__( self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() + if etp_group is None: + self.etp_size = 1 + else: + self.etp_size = get_distributed_world_size(etp_group) + self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes @@ -1297,6 +1324,10 @@ def __init__( self.reset_parameters(defer_init=device == "meta") + if etp_group is not None: + wrap_module_params_etp(self, self.weight_names, etp_group) + del weight_tensor + # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.parallel_mode == "row" and self.apply_bias: @@ -1399,6 +1430,11 @@ def forward( try: weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + if self.etp_size > 1: + weight_tensor.setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = ( self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) if not debug @@ -1459,6 +1495,7 @@ def forward( self.symmetric_ar_type, self.save_original_input, debug, + self.etp_size, ) out = linear_fn( *autograd_ctx,