43 commits
e49dd75
ETP
jiemingz Mar 5, 2026
4e0e39d
cleanup
jiemingz Mar 5, 2026
de29fac
fix post_hook not being called in certain cases
jiemingz Mar 10, 2026
9b846dc
Cudagraph Support
jiemingz Mar 19, 2026
9cb1379
debug: make etp link table log human readable.
fanshiqing Mar 30, 2026
a5ef675
doc: add README_ETP.md
fanshiqing Mar 30, 2026
92fc0f0
debug: add default and meaningful nvtx_label for ETP (batched) AG/RS …
fanshiqing Mar 30, 2026
b639271
doc: update README
fanshiqing Mar 31, 2026
94e62d1
ETP+CG: fix '--overlap-grad-reduce --overlap-param-gather'
fanshiqing Apr 3, 2026
f4b5a5e
ETP padding: fix stripping for rowwise_scale_inv and columnwise_scale…
fanshiqing Apr 3, 2026
89d8ae7
ETP: add UTs and doc update.
fanshiqing Apr 3, 2026
1d771ff
import fix
fanshiqing Apr 7, 2026
fd55ede
move ag init to first pass
jiemingz Apr 9, 2026
62379f5
ETP+CG: 2-chain(dense+expert, no cross link prefetching) + shared ag/…
fanshiqing Apr 9, 2026
ba909cc
fix the case when ETP_Config.weight_prefetch is False.
fanshiqing Apr 10, 2026
fd65c96
ETP+emb/output layers: remove these two layers from the prefetch chain.
fanshiqing Apr 13, 2026
6504611
ETP+CG mem fix1: use pooled buffers for both async and sync gathers t…
fanshiqing Apr 14, 2026
14db814
ETP+CG mem fix2: fix wgrad tensor retention and eliminate redundant A…
fanshiqing Apr 14, 2026
e9acc1b
ETP+CG mem fix3: release gathered expert weights after dgrad GEMM in …
fanshiqing Apr 14, 2026
96515f1
[Conservative] fix ETP+CG+DDPOverlapping hang: serialize DDP RS and EE…
fanshiqing Apr 14, 2026
d78bd53
ETP+CG+DDP final fix: ETP: restore register_grad_accum_hook + _finali…
fanshiqing Apr 15, 2026
1fd09f5
ETP+CG mem fix4: del main_grads after batched_wgrad_reduce_scatter
fanshiqing Apr 15, 2026
66fb81c
update doc
fanshiqing Apr 16, 2026
e09983c
ETP+CG: re-enable bwd ETP RS overlapping across Graphs.
fanshiqing Apr 16, 2026
d72bcce
Revert "ETP+CG: re-enable bwd ETP RS overlapping across Graphs."
fanshiqing Apr 17, 2026
7cc86fd
ETP: fix iter-2 NaN + unbounded wgrad pool growth; partition streams
fanshiqing Apr 19, 2026
a652a63
ETP+nvfp4: coalescing amax reduction for fprop groupedgemm
fanshiqing Apr 21, 2026
c0538ec
remap quantized params
jiemingz Apr 21, 2026
7949e50
remap quantized param patch: (1) Remap quantized params into the CG m…
fanshiqing Apr 22, 2026
fcef590
ETP+CG: launch etp async rs on rs_stream to match issue-site invariant.
fanshiqing Apr 22, 2026
cd88d3b
ETP+NVFP4: fused multi-tensor amax kernel
fanshiqing Apr 23, 2026
f7a08f7
ETP: cache hot-path lookups in ETP to reduce python overhead.
fanshiqing Apr 24, 2026
9f614f4
ETP: disable check_param_states by default
fanshiqing Apr 24, 2026
eace39d
ETP divergence fix: revert async AG issue-site stream wrapper from 7c…
fanshiqing Apr 24, 2026
4ece705
ETP+CG: fix for flaky NaN issue at scale: async AG/RS issue on side s…
fanshiqing Apr 25, 2026
a4ce839
ETP: disable check_param_states by default
fanshiqing Apr 25, 2026
f446e02
ETP+CG: enable cross-graph RS overlap — main_grad.add_ on rs_stream
fanshiqing Apr 27, 2026
464c289
wgrad accum fusion
jiemingz Apr 23, 2026
06a2756
minor fix for wgrad accum fusion
fanshiqing Apr 27, 2026
e107b05
ETP+mxfp8: reject coalesced-amax path for non-NVFP4 quantizers
fanshiqing Apr 28, 2026
03bb9fb
ETP: add debug_numerics instrumentation for NaN/Inf triage
fanshiqing Apr 28, 2026
550f074
ETP+mxfp8 divergence fix: disable GEMM-swizzled scales for all-gather…
fanshiqing Apr 28, 2026
520251c
ETP: pad full tensor before sharding instead of per-rank on-the-fly
fanshiqing Apr 28, 2026
784 changes: 784 additions & 0 deletions docs/README_ETP.md

Large diffs are not rendered by default.

Binary file added docs/etp/etp_ep_nt6_schedule_bf16.png
1,411 changes: 1,411 additions & 0 deletions tests/pytorch/distributed/test_etp.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -163,6 +163,7 @@ list(APPEND transformer_engine_cuda_sources
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/multi_amax.cu
comm_gemm_overlap/userbuffers/userbuffers.cu)

list(APPEND transformer_engine_cuda_arch_specific_sources
20 changes: 20 additions & 0 deletions transformer_engine/common/include/transformer_engine/recipe.h
@@ -99,6 +99,26 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s
void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
const NVTEQuantizationConfig config, cudaStream_t stream);

/*! \brief Compute amax for a list of independent tensors in a single kernel launch.
*
* Unlike nvte_group_amax (which requires a single contiguous input split along dim 0),
* this API accepts arrays of independent input tensors, each with its own allocation.
* Designed for the ETP grouped-experts case where per-expert weights live in separate
* buffers. For each i in [0, num_tensors), computes amax(inputs[i]) and writes it to
* outputs[i]'s amax buffer. outputs[i] must be an FP8 per-tensor scaling or NVFP4 1D
* scaling tensor. All inputs must share the same dtype. If the list exceeds the
* per-launch batch capacity, it is internally chunked.
*
* \param[in] inputs Array of input tensors (unquantized). Size num_tensors.
* \param[in,out] outputs Array of output tensors. Only the amax is updated.
* Size num_tensors.
* \param[in] num_tensors Number of tensors.
* \param[in] config Quantization configuration (for noop_tensor). May be NULL.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_multi_compute_amax(const NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors,
const NVTEQuantizationConfig config, cudaStream_t stream);
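
// Illustrative call sequence (a sketch, not a normative example):
//
//   std::vector<NVTETensor> ins(num_experts), outs(num_experts);
//   // fill ins with per-expert high-precision weights; outs must already
//   // own amax buffers (FP8 per-tensor or NVFP4 1D scaling)
//   nvte_multi_compute_amax(ins.data(), outs.data(), ins.size(),
//                           nullptr /* config */, stream);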

/*! \brief Update an FP8 tensor's scale based on its amax.
*
* This is only supported for FP8 tensors with per-tensor scaling.
274 changes: 274 additions & 0 deletions transformer_engine/common/recipe/multi_amax.cu
@@ -0,0 +1,274 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <transformer_engine/recipe.h>

#include <algorithm>
#include <utility>
#include <vector>

#include "../common.h"
#include "../util/logging.h"
#include "../util/vectorized_pointwise.h"
#include "recipe_common.cuh"

namespace transformer_engine {
namespace {

constexpr int multi_amax_kernel_threads = 512;
// Per-launch capacity. kMaxTensorsPerBatch * ~40 bytes per slot keeps the args
// struct within the 4KB kernel parameter limit with comfortable headroom.
constexpr int kMaxTensorsPerBatch = 64;

struct MultiAmaxArgs {
const void *input_list[kMaxTensorsPerBatch];
void *output_rowwise_amax_list[kMaxTensorsPerBatch];
void *output_columnwise_amax_list[kMaxTensorsPerBatch];
size_t input_numel[kMaxTensorsPerBatch];
size_t num_aligned_elements[kMaxTensorsPerBatch];
int num_tensors;
};
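
// Sanity check on the size reasoning above: 5 arrays x 64 slots x 8 B plus an
// int is ~2.6 KB, comfortably inside the 4 KB kernel-parameter budget.
static_assert(sizeof(MultiAmaxArgs) <= 4096,
"MultiAmaxArgs must fit in the kernel parameter limit");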

// Zero out every output amax slot (rowwise + columnwise, deduped) in a single launch.
// Respects the noop_ptr contract shared with the single-tensor amax path.
__launch_bounds__(multi_amax_kernel_threads) __global__
void MultiZeroAmaxKernel(MultiAmaxArgs args, const float *noop_ptr) {
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < args.num_tensors; tid += stride) {
float *rw = static_cast<float *>(args.output_rowwise_amax_list[tid]);
float *cw = static_cast<float *>(args.output_columnwise_amax_list[tid]);
if (rw != nullptr) {
*rw = 0.0f;
}
if (cw != nullptr && cw != rw) {
*cw = 0.0f;
}
}
}

// Per-tensor amax with one block-strip per tensor. blockIdx.y selects the
// tensor; blockIdx.x is the work chunk within that tensor. Each block
// vector-loads the tensor, reduces across threads, and atomicMaxFloats the
// result into BOTH output amax slots (rowwise + columnwise, deduped). This
// subsumes the per-expert D2D copy that the single-tensor path does after the
// amax kernel.
template <int nvec, bool aligned, typename InputType>
__launch_bounds__(multi_amax_kernel_threads) __global__
void MultiAmaxKernel(MultiAmaxArgs args, const float *noop_ptr) {
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}

const int t_idx = blockIdx.y;
if (t_idx >= args.num_tensors) {
return;
}

const InputType *input = static_cast<const InputType *>(args.input_list[t_idx]);
const size_t N = args.input_numel[t_idx];
if (N == 0) {
return;
}
const size_t M = args.num_aligned_elements[t_idx];

VectorizedLoader<InputType, nvec, aligned> loader(input, N);
InputType max = InputType{0.f};
const int warp_id = threadIdx.x / THREADS_PER_WARP;

for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M;
tid += gridDim.x * blockDim.x) {
loader.load(tid, N);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const InputType val = static_cast<InputType>(loader.separate()[i]);
__builtin_assume(max >= InputType{0.f});
if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
#if __CUDA_ARCH__ >= 800
max = __hmax(__habs(val), max);
#else
max = static_cast<__nv_bfloat16>(
fmaxf(fabsf(static_cast<float>(val)), static_cast<float>(max)));
#endif
} else if constexpr (std::is_same_v<InputType, __half>) {
max = __hmax(__habs(val), max);
} else {
max = fmaxf(fabsf(val), max);
}
}
}

// Reduce amax over block.
max = reduce_max<multi_amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
float *rw = static_cast<float *>(args.output_rowwise_amax_list[t_idx]);
float *cw = static_cast<float *>(args.output_columnwise_amax_list[t_idx]);
if (rw != nullptr) {
atomicMaxFloat(rw, static_cast<float>(max));
}
if (cw != nullptr && cw != rw) {
atomicMaxFloat(cw, static_cast<float>(max));
}
}
}

template <typename InputType>
void launch_multi_amax_batch(const MultiAmaxArgs &args, size_t max_numel, Alignment align,
const float *noop_ptr, cudaStream_t stream) {
// Zero all amax outputs in one launch.
{
constexpr int threads = multi_amax_kernel_threads;
const int num_blocks = std::max(1, DIVUP(args.num_tensors, threads));
MultiZeroAmaxKernel<<<num_blocks, threads, 0, stream>>>(args, noop_ptr);
NVTE_CHECK_CUDA(cudaGetLastError());
}

if (max_numel == 0) {
return;
}

// Grid: y = tensor index, x = work chunks within the largest tensor. Blocks
// that exceed a shorter tensor's aligned element count bail out via the
// bounds check inside the kernel.
constexpr int nvec = 32 / sizeof(InputType);
constexpr size_t threads = multi_amax_kernel_threads;
const size_t max_aligned = (max_numel + nvec - 1) / nvec;
size_t num_blocks_x = DIVUP(max_aligned, threads);
constexpr size_t max_blocks = 65535;
num_blocks_x = std::min(num_blocks_x, max_blocks);
num_blocks_x = std::max<size_t>(num_blocks_x, 1);
dim3 grid(num_blocks_x, static_cast<unsigned int>(args.num_tensors), 1);

switch (align) {
case Alignment::SAME_ALIGNED:
MultiAmaxKernel<nvec, true, InputType>
<<<grid, threads, 0, stream>>>(args, noop_ptr);
break;
case Alignment::SAME_UNALIGNED:
MultiAmaxKernel<nvec, false, InputType>
<<<grid, threads, 0, stream>>>(args, noop_ptr);
break;
case Alignment::DIFFERENT:
// Heterogeneous alignment across tensors: the aligned=false loader is
// constructed per tensor inside the kernel, so it handles any mix of
// pointer alignments. (An nvec=1 fall-back would be wrong here because
// num_aligned_elements was precomputed for this nvec.)
MultiAmaxKernel<nvec, false, InputType>
<<<grid, threads, 0, stream>>>(args, noop_ptr);
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
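
// Worked launch-shape example (illustrative numbers, not a measurement):
// 8 bf16 experts of 4096 x 14336 each -> nvec = 32/2 = 16,
// max_aligned = 58,720,256 / 16 = 3,670,016,
// num_blocks_x = 3,670,016 / 512 = 7,168 -> grid = (7168, 8, 1).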

// Fill one MultiAmaxArgs batch from a slice of the full input/output list.
// Returns (max_numel in this batch, worst-case alignment across the batch).
template <typename InputType>
std::pair<size_t, Alignment> build_batch_args(const std::vector<Tensor *> &inputs,
const std::vector<Tensor *> &outputs, size_t start,
size_t count, MultiAmaxArgs &args) {
constexpr int nvec = 32 / sizeof(InputType);
size_t max_numel = 0;
// SAME_ALIGNED is the most optimistic; degrade to SAME_UNALIGNED if any
// tensor is merely same-layout but unaligned, to DIFFERENT if alignment
// varies across tensors.
Alignment batch_align = Alignment::SAME_ALIGNED;
for (size_t i = 0; i < count; ++i) {
const Tensor &inp = *inputs[start + i];
Tensor &out = *outputs[start + i];
const size_t N = inp.data.numel();
void *rw_ptr = out.amax.dptr;
void *cw_ptr = out.columnwise_amax.dptr;

args.input_list[i] = inp.data.dptr;
args.output_rowwise_amax_list[i] = rw_ptr;
args.output_columnwise_amax_list[i] = cw_ptr;
args.input_numel[i] = N;
args.num_aligned_elements[i] = get_num_aligned_elements(inp.data.dptr, N, nvec,
sizeof(InputType));
max_numel = std::max(max_numel, N);

// Fold this tensor's alignment into the batch decision. CheckAlignment on a
// single pointer yields SAME_ALIGNED or SAME_UNALIGNED; mixing the two across
// tensors degrades the batch to SAME_UNALIGNED, which stays correct because
// the unaligned loader is built per tensor inside the kernel and tolerates
// any pointer alignment. DIFFERENT is retained only as a defensive fall-back.
if (N > 0) {
Alignment a = CheckAlignment(N, nvec, static_cast<const InputType *>(inp.data.dptr));
if (a == Alignment::DIFFERENT) {
batch_align = Alignment::DIFFERENT;
} else if (a != batch_align && batch_align != Alignment::DIFFERENT) {
batch_align = Alignment::SAME_UNALIGNED;
}
}
}
args.num_tensors = static_cast<int>(count);
return {max_numel, batch_align};
}

void multi_compute_amax_impl(const NVTETensor *inputs_, NVTETensor *outputs_, size_t num_tensors,
const NVTEQuantizationConfig config_, cudaStream_t stream) {
if (num_tensors == 0) {
return;
}
NVTE_CHECK(inputs_ != nullptr, "nvte_multi_compute_amax: inputs is NULL");
NVTE_CHECK(outputs_ != nullptr, "nvte_multi_compute_amax: outputs is NULL");

// Convert, validate, collect into plain vectors.
std::vector<Tensor *> inputs(num_tensors);
std::vector<Tensor *> outputs(num_tensors);
DType input_dtype;
for (size_t i = 0; i < num_tensors; ++i) {
inputs[i] = convertNVTETensorCheck(inputs_[i]);
outputs[i] = convertNVTETensorCheck(outputs_[i]);
const auto &inp = *inputs[i];
auto &out = *outputs[i];
NVTE_CHECK(inp.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"nvte_multi_compute_amax: input[", i,
"] must be unquantized, got scaling_mode=", to_string(inp.scaling_mode));
NVTE_CHECK(!is_fp8_dtype(inp.data.dtype),
"nvte_multi_compute_amax: input[", i,
"] must be unquantized, got dtype=", to_string(inp.data.dtype));
if (i == 0) {
input_dtype = inp.data.dtype;
} else {
NVTE_CHECK(inp.data.dtype == input_dtype,
"nvte_multi_compute_amax: all inputs must share dtype; input[0]=",
to_string(input_dtype), ", input[", i, "]=", to_string(inp.data.dtype));
}
NVTE_CHECK(out.scaling_mode == NVTE_DELAYED_TENSOR_SCALING ||
out.scaling_mode == NVTE_NVFP4_1D_SCALING,
"nvte_multi_compute_amax: output[", i, "] must be FP8 per-tensor or NVFP4 1D");
NVTE_CHECK(out.amax.dptr != nullptr || out.columnwise_amax.dptr != nullptr,
"nvte_multi_compute_amax: output[", i, "] has no amax buffer");
}

const float *noop_ptr = nullptr;
if (config_ != nullptr) {
const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
const NVTETensor noop = config_cpp->noop_tensor;
noop_ptr = reinterpret_cast<float *>(
(noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
}

// Chunk across kMaxTensorsPerBatch launches (single launch in the common 8-expert case).
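// e.g. an (illustrative) 130-tensor list would run as batches of 64, 64, and 2.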
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input_dtype, IType, {
for (size_t start = 0; start < num_tensors; start += kMaxTensorsPerBatch) {
const size_t count = std::min<size_t>(kMaxTensorsPerBatch, num_tensors - start);
MultiAmaxArgs args = {};
auto [max_numel, batch_align] = build_batch_args<IType>(inputs, outputs, start, count, args);
launch_multi_amax_batch<IType>(args, max_numel, batch_align, noop_ptr, stream);
}
}); // NOLINT(*)
}

} // anonymous namespace
} // namespace transformer_engine

void nvte_multi_compute_amax(const NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors,
const NVTEQuantizationConfig config, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_compute_amax);
transformer_engine::multi_compute_amax_impl(inputs, outputs, num_tensors, config, stream);
}
21 changes: 20 additions & 1 deletion transformer_engine/pytorch/csrc/common.h
@@ -365,11 +365,30 @@ class NVFP4Quantizer : public Quantizer {
*/
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out);

/*! @brief Compute (and D2D fill) local amax only — no cast, no allreduce.
*
* Writes the local amax into out's rowwise and/or columnwise amax
* buffers. Callers are expected to perform a coalesced allreduce
* across the amax reduction group afterwards, then invoke
* quantize_cast_only to finish the cast with the reduced amax.
*/
void compute_amax_only(const TensorWrapper& input, TensorWrapper& out);

/*! @brief Cast to NVFP4 assuming amax already reduced externally.
*
* Skips both local amax compute and the internal amax allreduce.
* Callers must guarantee out's amax buffers already hold the reduced
* amax (e.g. via compute_amax_only + allreduce_coalesced).
*/
void quantize_cast_only(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt);
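
// Illustrative split-phase flow (a sketch; the allreduce itself is
// caller-managed and its exact form is an assumption here):
//   q.compute_amax_only(input, out);   // local amax -> out's amax buffers
//   // coalesced MAX-allreduce of every out's amax across the reduction group
//   q.quantize_cast_only(input, out);  // cast using the reduced amax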

std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;

private:
void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
const std::optional<TensorWrapper>& noop_flag, bool compute_amax,
bool skip_amax_reduction = false);
};

std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
15 changes: 15 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -285,6 +285,21 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
std::optional<at::Tensor> noop_flag);

// NVFP4-only split-phase quantize: compute amax, coalesce allreduce externally, then cast.
py::object compute_amax_nvfp4(const at::Tensor &tensor, py::handle quantizer,
const py::object &output);
py::object quantize_cast_only_nvfp4(const at::Tensor &tensor, py::handle quantizer,
const py::object &output,
std::optional<at::Tensor> noop_flag);

// NVFP4-only multi-tensor amax: fuses N per-expert (zero_amax + amax + D2D replicate)
// chains into a single pair of kernel launches (one multi-zero + one multi-amax) that
// writes amax into every output's rowwise AND columnwise buffers. Outputs must be
// pre-allocated; amax is written in place, no return.
void compute_multi_amax_nvfp4(const std::vector<at::Tensor> &tensor_list,
std::vector<py::handle> quantizer_list,
const std::vector<py::object> &output_list);
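
// Illustrative driver sequence (a sketch; list construction is caller-specific):
//   compute_multi_amax_nvfp4(weights, quantizers, outputs);  // one fused launch pair
//   // ... coalesced MAX-allreduce of every output's amax buffer ...
//   for (size_t i = 0; i < weights.size(); ++i)
//     quantize_cast_only_nvfp4(weights[i], quantizers[i], outputs[i], std::nullopt);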

py::object dequantize(const py::handle &input, DType otype);

py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors,