From 4eb228a228e0621081f03e751c6137aec761e719 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 29 Jan 2026 07:14:49 +0000 Subject: [PATCH 1/6] fused scaling and unscaling of bf16 momentum Signed-off-by: Xin Yao --- transformer_engine/common/common.h | 15 ++ .../common/multi_tensor/adam.cu | 222 +++++++++++------- .../pytorch/optimizers/fused_adam.py | 32 ++- 3 files changed, 176 insertions(+), 93 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 970b7aef6c..eb58f5b4a1 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -717,6 +717,21 @@ struct TypeInfo { NVTE_ERROR("Invalid type."); \ } +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type, expected Float32 or BFloat16."); \ + } + // Add a pack_size argument to select the packed type for FP4 #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \ switch (dtype) { \ diff --git a/transformer_engine/common/multi_tensor/adam.cu b/transformer_engine/common/multi_tensor/adam.cu index 5d89179c44..29a073be84 100644 --- a/transformer_engine/common/multi_tensor/adam.cu +++ b/transformer_engine/common/multi_tensor/adam.cu @@ -49,7 +49,7 @@ struct FP8Data { template <> struct FP8Data {}; -template +template struct AdamFunctorMaster { static constexpr bool is_fp8_type = is_fp8::value; @@ -79,10 +79,10 @@ struct AdamFunctorMaster { PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; - FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); m += chunk_idx * chunk_size; - FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); v += chunk_idx * chunk_size; FULL_T *p_master = reinterpret_cast(tl.addresses[4][tensor_loc]); @@ -147,8 +147,8 @@ struct AdamFunctorMaster { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { p_master[i] = static_cast(r_p[ii]); - m[i] = static_cast(r_m[ii]); - v[i] = static_cast(r_v[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); if constexpr (is_fp8_type) { __builtin_assume(fp8_data.max >= 0); fp8_data.max = fmaxf(fabsf(r_p[ii]), fp8_data.max); @@ -175,7 +175,7 @@ struct AdamFunctorMaster { } }; -template +template struct AdamFunctorMasterParamRemainder { __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, TensorListMetadata<5> &tl, // NOLINT(*) @@ -194,10 +194,10 @@ struct AdamFunctorMasterParamRemainder { int16_t *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; - FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); m += chunk_idx * chunk_size; - FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); v += chunk_idx * chunk_size; int16_t *p_remainder = reinterpret_cast(tl.addresses[4][tensor_loc]); @@ -283,15 +283,15 @@ struct AdamFunctorMasterParamRemainder { p_remainder[i] = local_p_rem[ii]; p[i] = local_p[ii]; - m[i] = static_cast(r_m[ii]); - v[i] = static_cast(r_v[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = 
static_cast(r_v[ii]); } } } } }; -template +template struct AdamFunctor { __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, // NOLINT(*) @@ -317,10 +317,10 @@ struct AdamFunctor { PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; - FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); m += chunk_idx * chunk_size; - FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); v += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; @@ -372,15 +372,15 @@ struct AdamFunctor { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { p[i] = static_cast(r_p[ii]); - m[i] = static_cast(r_m[ii]); - v[i] = static_cast(r_v[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); } } } } }; -template +template struct AdamCapturableFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, // NOLINT(*) @@ -410,10 +410,10 @@ struct AdamCapturableFunctor { T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; - FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); m += chunk_idx * chunk_size; - FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); v += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; @@ -466,15 +466,15 @@ struct AdamCapturableFunctor { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { p[i] = static_cast(r_p[ii]); - m[i] = static_cast(r_m[ii]); - v[i] = static_cast(r_v[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); } } } } }; -template +template struct AdamCapturableMasterFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<5> &tl, // NOLINT(*) @@ -504,10 +504,10 @@ struct AdamCapturableMasterFunctor { T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; - FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); m += chunk_idx * chunk_size; - FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); v += chunk_idx * chunk_size; FULL_T *p_master = reinterpret_cast(tl.addresses[4][tensor_loc]); @@ -564,8 +564,8 @@ struct AdamCapturableMasterFunctor { if (i < n && i < chunk_size) { p[i] = static_cast(r_p[ii]); p_master[i] = static_cast(r_p[ii]); - m[i] = static_cast(r_m[ii]); - v[i] = static_cast(r_v[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); } } } @@ -606,12 +606,17 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK(tensor_lists[1][j]->dtype() == p_in_type_te, "Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()), ", but expected dtype=", to_string(p_in_type_te)); - NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, - " has dtype=", to_string(tensor_lists[2][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); - NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, - " has dtype=", to_string(tensor_lists[3][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); + { + 
const bool m_is_fp32 = tensor_lists[2][j]->dtype() == DType::kFloat32; + const bool m_is_bf16 = tensor_lists[2][j]->dtype() == DType::kBFloat16; + const bool v_is_fp32 = tensor_lists[3][j]->dtype() == DType::kFloat32; + const bool v_is_bf16 = tensor_lists[3][j]->dtype() == DType::kBFloat16; + NVTE_CHECK((m_is_fp32 && v_is_fp32) || (m_is_bf16 && v_is_bf16), + "First and second moment tensors must both be Float32 or both be BFloat16, but " + "tensor ", + j, " has first moment dtype=", to_string(tensor_lists[2][j]->dtype()), + " and second moment dtype=", to_string(tensor_lists[3][j]->dtype())); + } if (num_tensor_lists == 5) { NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j, " has dtype=", to_string(tensor_lists[4][j]->dtype()), @@ -633,6 +638,9 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, } } + // Get moment dtype (m and v have the same dtype, already validated above) + const auto moment_type_te = tensor_lists[2][0]->dtype(); + // Launch kernel if (requires_64bit_indexing) { if (num_tensor_lists == 4) { @@ -641,22 +649,26 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, - tensor_lists, - AdamFunctor(), stream, - beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, - (adamMode_t)mode, weight_decay);)); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<4>( + (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctor(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);))); } else { // g, p, m, v, p_master TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, - tensor_lists, - AdamFunctorMaster(), - stream, beta1, beta2, bias_correction1, bias_correction2, - epsilon, lr, (adamMode_t)mode, weight_decay);)); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<5>( + (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), + stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);))); } } else { if (num_tensor_lists == 4) { @@ -665,20 +677,26 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctor(), stream, - beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, - (adamMode_t)mode, weight_decay);)); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<4>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);))); } else { // g, p, m, v, p_master TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctorMaster(), - stream, beta1, beta2, bias_correction1, bias_correction2, - epsilon, lr, (adamMode_t)mode, weight_decay);)); + 
TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<5>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), + stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);))); } } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -716,24 +734,35 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK(tensor_lists[1][j]->dtype() == DType::kBFloat16, "Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()), ", but expected dtype=", to_string(DType::kBFloat16)); - NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, - " has dtype=", to_string(tensor_lists[2][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); - NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, - " has dtype=", to_string(tensor_lists[3][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); + { + const bool m_is_fp32 = tensor_lists[2][j]->dtype() == DType::kFloat32; + const bool m_is_bf16 = tensor_lists[2][j]->dtype() == DType::kBFloat16; + const bool v_is_fp32 = tensor_lists[3][j]->dtype() == DType::kFloat32; + const bool v_is_bf16 = tensor_lists[3][j]->dtype() == DType::kBFloat16; + NVTE_CHECK((m_is_fp32 && v_is_fp32) || (m_is_bf16 && v_is_bf16), + "First and second moment tensors must both be Float32 or both be BFloat16, but " + "tensor ", + j, " has first moment dtype=", to_string(tensor_lists[2][j]->dtype()), + " and second moment dtype=", to_string(tensor_lists[3][j]->dtype())); + } NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kInt16, "Param remainder tensor ", j, " has dtype=", to_string(tensor_lists[4][j]->dtype()), ", but expected dtype=", to_string(DType::kInt16)); } + // Get moment dtype (m and v have the same dtype, already validated above) + const auto moment_type_te = tensor_lists[2][0]->dtype(); + // Launch kernel TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctorMasterParamRemainder(), stream, - beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, - (adamMode_t)mode, weight_decay);); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<5>( + (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctorMasterParamRemainder(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, + weight_decay);)); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -812,17 +841,17 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, g_in_type_te, g_in_type, multi_tensor_apply<5, true>( (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctorMaster(), stream, beta1, beta2, + AdamFunctorMaster(), stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( fp8_dtype, FP8_T, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctorMaster(), - stream, beta1, beta2, bias_correction1, bias_correction2, - epsilon, lr, (adamMode_t)mode, weight_decay);)); + multi_tensor_apply<5, true>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), stream, beta1, beta2, + bias_correction1, bias_correction2, epsilon, lr, 
(adamMode_t)mode, weight_decay);)); } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -852,22 +881,32 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()), ", but expected dtype=", to_string(g_in_type_te)); - NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, - " has dtype=", to_string(tensor_lists[2][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); - NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, - " has dtype=", to_string(tensor_lists[3][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); + { + const bool m_is_fp32 = tensor_lists[2][j]->dtype() == DType::kFloat32; + const bool m_is_bf16 = tensor_lists[2][j]->dtype() == DType::kBFloat16; + const bool v_is_fp32 = tensor_lists[3][j]->dtype() == DType::kFloat32; + const bool v_is_bf16 = tensor_lists[3][j]->dtype() == DType::kBFloat16; + NVTE_CHECK((m_is_fp32 && v_is_fp32) || (m_is_bf16 && v_is_bf16), + "First and second moment tensors must both be Float32 or both be BFloat16, but " + "tensor ", + j, " has first moment dtype=", to_string(tensor_lists[2][j]->dtype()), + " and second moment dtype=", to_string(tensor_lists[3][j]->dtype())); + } } + // Get moment dtype (m and v have the same dtype, already validated above) + const auto moment_type_te = tensor_lists[2][0]->dtype(); + // Launch kernel TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, - multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamCapturableFunctor(), stream, beta1, beta2, - reinterpret_cast(step.data.dptr), bias_correction, epsilon, - reinterpret_cast(lr.data.dptr), (adamMode_t)mode, weight_decay, - reinterpret_cast(inv_scale.data.dptr));) + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamCapturableFunctor(), stream, beta1, + beta2, reinterpret_cast(step.data.dptr), bias_correction, + epsilon, reinterpret_cast(lr.data.dptr), (adamMode_t)mode, + weight_decay, reinterpret_cast(inv_scale.data.dptr));)) NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -897,25 +936,36 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()), ", but expected dtype=", to_string(g_in_type_te)); - NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, - " has dtype=", to_string(tensor_lists[2][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); - NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, - " has dtype=", to_string(tensor_lists[3][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); + { + const bool m_is_fp32 = tensor_lists[2][j]->dtype() == DType::kFloat32; + const bool m_is_bf16 = tensor_lists[2][j]->dtype() == DType::kBFloat16; + const bool v_is_fp32 = tensor_lists[3][j]->dtype() == DType::kFloat32; + const bool v_is_bf16 = tensor_lists[3][j]->dtype() == DType::kBFloat16; + NVTE_CHECK((m_is_fp32 && v_is_fp32) || (m_is_bf16 && v_is_bf16), + "First and second moment tensors must both be Float32 or both be BFloat16, but " + "tensor ", + j, " has first moment dtype=", to_string(tensor_lists[2][j]->dtype()), + " and second moment 
dtype=", to_string(tensor_lists[3][j]->dtype())); + } NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j, " has dtype=", to_string(tensor_lists[4][j]->dtype()), ", but expected dtype=", to_string(DType::kFloat32)); } + // Get moment dtype (m and v have the same dtype, already validated above) + const auto moment_type_te = tensor_lists[2][0]->dtype(); + // Launch kernel TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, - multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamCapturableMasterFunctor(), stream, beta1, beta2, - reinterpret_cast(step.data.dptr), bias_correction, epsilon, - reinterpret_cast(lr.data.dptr), (adamMode_t)mode, weight_decay, - reinterpret_cast(inv_scale.data.dptr));) + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamCapturableMasterFunctor(), stream, + beta1, beta2, reinterpret_cast(step.data.dptr), + bias_correction, epsilon, reinterpret_cast(lr.data.dptr), + (adamMode_t)mode, weight_decay, + reinterpret_cast(inv_scale.data.dptr));)) NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 495056d652..866f564491 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -207,6 +207,9 @@ def __init__( self.store_param_remainders = ( store_param_remainders and master_weights and master_weight_dtype == torch.float32 ) + # If the exp_avg and exp_avg_sq dtypes are bfloat16, we can fuse the unscaling/scaling + # operations into the fused Adam kernel. + self.fuse_unscale = self.exp_avg_dtype == self.exp_avg_sq_dtype == torch.bfloat16 # Deprecated options self.set_grad_none = set_grad_none @@ -268,10 +271,9 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): dtype = self.name_to_dtype_map[state_name] if dtype == torch.uint8: assert isinstance(scaled_state, Float8Tensor) - assert len(scaled_state._quantizer.scale) == 1, ( - "Only scaling with one scaling factor per tensor is supported by the" - " FusedAdam." - ) + assert ( + len(scaled_state._quantizer.scale) == 1 + ), "Only scaling with one scaling factor per tensor is supported by the FusedAdam." else: assert scaled_state.dtype == dtype @@ -293,13 +295,22 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): unscaled_state.mul_(rscale) scaled_state.copy_(unscaled_state) - def get_unscaled_state(self, param, state_name): + def get_unscaled_state( + self, param: torch.nn.Parameter, state_name: str, skip_unscale: bool = False + ) -> torch.Tensor: """Return the unscaled state corresponding to the input `param` and `state_name`. Arguments: param (torch.nn.Parameter): One of parameters in this optimizer. state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', and 'master_param`. + skip_unscale (optional, bool): Whether to skip the unscaling operation. + Should only be True if 'self.fuse_unscale' is True. Default is False. + + Returns: + torch.Tensor: The unscaled state. Note that if the state is in BF16, the returned + tensor is still in BF16 because it doesn't require to be "unscaled", otherwise it + will be unscaled to FP32. 
""" state = self.state[param] dtype = self.name_to_dtype_map[state_name] @@ -321,7 +332,10 @@ def get_unscaled_state(self, param, state_name): unscaled = state[state_name] elif dtype == torch.bfloat16: assert state[state_name].dtype == torch.bfloat16 - unscaled = state[state_name].float() + if skip_unscale: + unscaled = state[state_name] + else: + unscaled = state[state_name].float() else: raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/bf16/fp32.") return unscaled @@ -565,7 +579,9 @@ def step(self, closure=None, grad_scaler=None): unscaled_state[name] = self.state[p][name] assert unscaled_state[name].dtype == torch.int16 else: - unscaled = self.get_unscaled_state(p, name) + unscaled = self.get_unscaled_state( + p, name, skip_unscale=self.fuse_unscale + ) unscaled_state[name] = unscaled if self.name_to_dtype_map[name] != torch.float32: unscaled_lists[name].append(unscaled) @@ -748,6 +764,8 @@ def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=N # Scaling for name in ["exp_avg", "exp_avg_sq", "master_param"]: + if self.fuse_unscale and name in ["exp_avg", "exp_avg_sq"]: + continue if len(unscaled_lists[name]) > 0: for unscaled, scaled, scale in zip( unscaled_lists[name], scaled_lists[name], state_scales[name] From 38388ef0fdea298eb081e32da988a85b3ad81323 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 29 Jan 2026 07:20:55 +0000 Subject: [PATCH 2/6] add more comments Signed-off-by: Xin Yao --- transformer_engine/pytorch/optimizers/fused_adam.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 866f564491..143d4adb88 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -765,6 +765,8 @@ def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=N # Scaling for name in ["exp_avg", "exp_avg_sq", "master_param"]: if self.fuse_unscale and name in ["exp_avg", "exp_avg_sq"]: + # When fused_unscale is True, the scaling is fused into the Adam kernel. + # The momentums are updated inplace, so we don't need to scale here. continue if len(unscaled_lists[name]) > 0: for unscaled, scaled, scale in zip( From a608ec8840711cede7246e1a924334a6b0cc67d8 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 29 Jan 2026 07:39:41 +0000 Subject: [PATCH 3/6] enable cuda graphs for bf16 momentums Signed-off-by: Xin Yao --- .../pytorch/optimizers/fused_adam.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 143d4adb88..715619de59 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -140,17 +140,24 @@ def __init__( if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]: raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg_sq.") - # Currently, capturable mode only supports fp32 master weights and optimizer states. - # The reason is, if the master weights or optimizer states are not in fp32 dtype, - # they will be copied to temporary fp32 buffers first. These fp32 buffers are then - # used as inputs for the kernel. Consequently, the pointer for earch `.step()` differs, - # making CUDA Graph inapplicable in this scenario. 
+ # Capturable mode requires fp32 master weights, and optimizer states (exp_avg/exp_avg_sq) + # must both be fp32 or both be bf16. This is because master weights in non-fp32 dtypes + # or optimizer states in non-fp32/bf16 dtypes require copying to temporary fp32 buffers + # before kernel execution, causing different pointers on each `.step()` call and making + # CUDA Graph inapplicable. if capturable and master_weights and master_weight_dtype != torch.float32: raise RuntimeError("Capturable mode only supports fp32 master weights.") - if capturable and exp_avg_dtype != torch.float32: - raise RuntimeError("Capturable mode only supports fp32 exp_avg.") - if capturable and exp_avg_sq_dtype != torch.float32: - raise RuntimeError("Capturable mode only supports fp32 exp_avg_sq") + if capturable: + valid_moment_dtypes = ( + exp_avg_dtype == exp_avg_sq_dtype == torch.float32 + or exp_avg_dtype == exp_avg_sq_dtype == torch.bfloat16 + ) + if not valid_moment_dtypes: + raise RuntimeError( + "Capturable mode requires exp_avg_dtype and exp_avg_sq_dtype to be " + "both torch.float32 or both torch.bfloat16, but got " + f"exp_avg_dtype={exp_avg_dtype} and exp_avg_sq_dtype={exp_avg_sq_dtype}." + ) if capturable and store_param_remainders: raise RuntimeError("Capturable mode doesn't support storing param remainders") From 893185296890fb5de6b1b2b5508f817b4dbd08fe Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 29 Jan 2026 07:47:44 +0000 Subject: [PATCH 4/6] add tests Signed-off-by: Xin Yao --- tests/pytorch/test_fused_optimizer.py | 28 ++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index f70be45918..f754fbe997 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -407,6 +407,20 @@ def test_bf16_exp_avg_sq(self): master_atol=2e-3, ) + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") + def test_bf16_exp_avg_and_exp_avg_sq(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.bfloat16, + exp_avg_sq_dtype=torch.bfloat16, + master_rtol=2e-3, + master_atol=2e-3, + ) + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg_sq(self): @@ -553,7 +567,7 @@ def forward(self, x): return y -class AdamTest: +class TestAdamTest: def setup_method(self, *, seed: int = 0) -> None: torch.manual_seed(seed) @@ -569,8 +583,8 @@ def setup_method(self, *, seed: int = 0) -> None: def test_grad_scaler(self): params_ = [p for p in self.model_.parameters() if p.requires_grad] optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False) - scaler = torch.cuda.amp.GradScaler(enabled=True) - scaler_ = torch.cuda.amp.GradScaler(enabled=True) + scaler = torch.amp.GradScaler('cuda',enabled=True) + scaler_ = torch.amp.GradScaler('cuda',enabled=True) for i in range(100): x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) @@ -620,8 +634,8 @@ def test_grad_scaler(self): def test_grad_scaler_capturable(self): params_ = [p for p in self.model_.parameters() if p.requires_grad] optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=True) - scaler = torch.cuda.amp.GradScaler(enabled=True) - scaler_ = torch.cuda.amp.GradScaler(enabled=True) + scaler = 
torch.amp.GradScaler('cuda',enabled=True) + scaler_ = torch.amp.GradScaler('cuda',enabled=True) for i in range(100): x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) @@ -678,8 +692,8 @@ def test_grad_scaler_capturable_master(self): optimizer_ = te.optimizers.FusedAdam( params_, lr=self.lr, capturable=True, master_weights=master_weights ) - scaler = torch.cuda.amp.GradScaler(enabled=True) - scaler_ = torch.cuda.amp.GradScaler(enabled=True) + scaler = torch.amp.GradScaler('cuda',enabled=True) + scaler_ = torch.amp.GradScaler('cuda',enabled=True) for i in range(100): x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) From b3b419ce8db50f55843b02034dbacd9d0b4061bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jan 2026 07:49:41 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fused_optimizer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index f754fbe997..185b9b85bc 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -583,8 +583,8 @@ def setup_method(self, *, seed: int = 0) -> None: def test_grad_scaler(self): params_ = [p for p in self.model_.parameters() if p.requires_grad] optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False) - scaler = torch.amp.GradScaler('cuda',enabled=True) - scaler_ = torch.amp.GradScaler('cuda',enabled=True) + scaler = torch.amp.GradScaler("cuda", enabled=True) + scaler_ = torch.amp.GradScaler("cuda", enabled=True) for i in range(100): x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) @@ -634,8 +634,8 @@ def test_grad_scaler(self): def test_grad_scaler_capturable(self): params_ = [p for p in self.model_.parameters() if p.requires_grad] optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=True) - scaler = torch.amp.GradScaler('cuda',enabled=True) - scaler_ = torch.amp.GradScaler('cuda',enabled=True) + scaler = torch.amp.GradScaler("cuda", enabled=True) + scaler_ = torch.amp.GradScaler("cuda", enabled=True) for i in range(100): x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) @@ -692,8 +692,8 @@ def test_grad_scaler_capturable_master(self): optimizer_ = te.optimizers.FusedAdam( params_, lr=self.lr, capturable=True, master_weights=master_weights ) - scaler = torch.amp.GradScaler('cuda',enabled=True) - scaler_ = torch.amp.GradScaler('cuda',enabled=True) + scaler = torch.amp.GradScaler("cuda", enabled=True) + scaler_ = torch.amp.GradScaler("cuda", enabled=True) for i in range(100): x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) From 19ff141e7e0e08f0315f35f5405faf3d9e43b55a Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Fri, 30 Jan 2026 04:38:23 +0000 Subject: [PATCH 6/6] update the check for store_param_remainders and capturable Signed-off-by: Xin Yao --- transformer_engine/pytorch/optimizers/fused_adam.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 715619de59..a87d968334 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -158,8 +158,6 @@ def __init__( "both 
torch.float32 or both torch.bfloat16, but got " f"exp_avg_dtype={exp_avg_dtype} and exp_avg_sq_dtype={exp_avg_sq_dtype}." ) - if capturable and store_param_remainders: - raise RuntimeError("Capturable mode doesn't support storing param remainders") # If the optimizer is capturable then LR should be a tensor (on GPU) lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr @@ -214,6 +212,8 @@ def __init__( self.store_param_remainders = ( store_param_remainders and master_weights and master_weight_dtype == torch.float32 ) + if self.capturable and self.store_param_remainders: + raise RuntimeError("Capturable mode doesn't support storing param remainders") # If the exp_avg and exp_avg_sq dtypes are bfloat16, we can fuse the unscaling/scaling # operations into the fused Adam kernel. self.fuse_unscale = self.exp_avg_dtype == self.exp_avg_sq_dtype == torch.bfloat16