From 83025fc8c0fa21e1737cee6b20fd36010ee63d92 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 5 May 2026 18:01:03 +0000 Subject: [PATCH 1/7] Use fast unfused cast mxfp8 kernels by default Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/quantize_mxfp8.cuh | 2 +- .../cast/mxfp8/specialized/quantize_mxfp8.cuh | 20 ++++--------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a0ae7dde82..76161befbe 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -644,7 +644,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, if (specialized::hasSpec() && - !WITH_GEMM_SWIZZLED_SCALES) { + !WITH_GEMM_SWIZZLED_SCALES && (scaling_type != ScalingType::COLWISE)) { switch (scaling_type) { case ScalingType::ROWWISE: { using traits = specialized::CastTraits; diff --git a/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh index 41e62ac319..9459f0273a 100644 --- a/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh @@ -91,18 +91,6 @@ __device__ __forceinline__ e8m0_t to_e8m0(IType amax) { #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // anonymous namespace -inline bool is_cast_only_enabled() { - static bool enabled = []() { - const char *env = std::getenv("ENABLE_CAST_ONLY"); - return env != nullptr && (env[0] == '1'); - }(); - return enabled; - - // // FIXME: when finish debugging, remove this - // const char* env = std::getenv("ENABLE_CAST_ONLY"); - // return env != nullptr && (env[0] == '1'); -} - template inline bool hasSpec() { return false; @@ -112,19 +100,19 @@ inline bool hasSpec() { // OType could be [fp8e5m2, fp8e4m3] template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template From e926c9a6542ce0c8469cd73b4a99033221a8983b Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 5 May 2026 18:07:57 +0000 Subject: [PATCH 2/7] Removed dead code Signed-off-by: Oleg Goncharov --- transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 76161befbe..bdd063f4f8 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -664,10 +664,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } - case ScalingType::COLWISE: { - NVTE_WARN("Colwise scaling will fallback to original kernel."); - break; - } case ScalingType::BIDIMENSIONAL: { using traits = specialized::CastTraits; auto kernel = specialized::quantize_mxfp8_kernel_cast_only; From 9ae6664d85a4d076216ea658965afdd4bfbe18e2 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 6 May 2026 12:42:23 +0000 Subject: [PATCH 3/7] Use fast kernel for full 32-element chunks only 
Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/quantize_mxfp8.cuh | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index bdd063f4f8..91560e8698 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -643,8 +643,18 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION( with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + // The specialized rowwise cast-only kernel vectorizes full 32-element chunks. + // Shapes with a partial row tail (for example, N=48) must use the generic kernel, + // otherwise the last chunk reads/writes past the logical end of the row. + const bool is_full_rowwise_chunk = + (cols % specialized::CastTraits::chunkElems == 0); + + const bool scaling_type_has_specialized_support = + (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk) || + (scaling_type == ScalingType::BIDIMENSIONAL); + if (specialized::hasSpec() && - !WITH_GEMM_SWIZZLED_SCALES && (scaling_type != ScalingType::COLWISE)) { + !WITH_GEMM_SWIZZLED_SCALES && scaling_type_has_specialized_support) { switch (scaling_type) { case ScalingType::ROWWISE: { using traits = specialized::CastTraits; From 717038527c6d2dd5b64ca450b606538f38fa8ef3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 12:43:22 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 91560e8698..a0ac98a028 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -647,11 +647,11 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // Shapes with a partial row tail (for example, N=48) must use the generic kernel, // otherwise the last chunk reads/writes past the logical end of the row. 
const bool is_full_rowwise_chunk = - (cols % specialized::CastTraits::chunkElems == 0); + (cols % specialized::CastTraits::chunkElems == 0); - const bool scaling_type_has_specialized_support = - (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk) || - (scaling_type == ScalingType::BIDIMENSIONAL); + const bool scaling_type_has_specialized_support = + (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk) || + (scaling_type == ScalingType::BIDIMENSIONAL); if (specialized::hasSpec() && !WITH_GEMM_SWIZZLED_SCALES && scaling_type_has_specialized_support) { From 9956d3a2f0310d94fcedb5f1c595f3f889798d04 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Thu, 7 May 2026 14:25:24 +0000 Subject: [PATCH 5/7] Fix Signed-off-by: Oleg Goncharov --- transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a0ac98a028..f42ee6cdd3 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -643,11 +643,10 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION( with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, - // The specialized rowwise cast-only kernel vectorizes full 32-element chunks. + // The specialized rowwise cast-only kernel vectorizes full 128-element chunks. // Shapes with a partial row tail (for example, N=48) must use the generic kernel, // otherwise the last chunk reads/writes past the logical end of the row. - const bool is_full_rowwise_chunk = - (cols % specialized::CastTraits::chunkElems == 0); + const bool is_full_rowwise_chunk = (cols % 128 == 0); const bool scaling_type_has_specialized_support = (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk) || From 3a01e531b6ad19cec73e287a0c9c0126718b513c Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 8 May 2026 17:54:33 +0000 Subject: [PATCH 6/7] Fixed grid size overflow Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/quantize_mxfp8.cuh | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index f42ee6cdd3..409d6a6e60 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -646,11 +646,20 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // The specialized rowwise cast-only kernel vectorizes full 128-element chunks. // Shapes with a partial row tail (for example, N=48) must use the generic kernel, // otherwise the last chunk reads/writes past the logical end of the row. 
- const bool is_full_rowwise_chunk = (cols % 128 == 0); + using rowwise_traits = specialized::CastTraits; + using bidimensional_traits = specialized::CastTraits; + constexpr size_t max_grid_dim_y = 65535; + const bool rowwise_specialized_grid_fits = + ((rows + rowwise_traits::blockDimM - 1) / rowwise_traits::blockDimM) <= max_grid_dim_y; + const bool bidimensional_specialized_grid_fits = + ((rows + bidimensional_traits::blockDIM::M - 1) / bidimensional_traits::blockDIM::M) <= max_grid_dim_y; + const bool is_full_rowwise_chunk = (cols % 128 == 0); const bool scaling_type_has_specialized_support = - (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk) || - (scaling_type == ScalingType::BIDIMENSIONAL); + (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk && + rowwise_specialized_grid_fits) || + (scaling_type == ScalingType::BIDIMENSIONAL && + bidimensional_specialized_grid_fits); if (specialized::hasSpec() && !WITH_GEMM_SWIZZLED_SCALES && scaling_type_has_specialized_support) { @@ -659,8 +668,8 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, using traits = specialized::CastTraits; auto kernel = specialized::quantize_mxfp8_kernel_cast_only; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - traits::smem); + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, traits::smem)); dim3 block(traits::threadLayout::num, traits::warpLayout::N, traits::warpLayout::M); @@ -677,8 +686,8 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, using traits = specialized::CastTraits; auto kernel = specialized::quantize_mxfp8_kernel_cast_only; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - traits::smem); + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, traits::smem)); // TMA for loading, so that we don't need STS for transposing alignas(64) CUtensorMap tensor_map_input{}; constexpr size_t input_type_bit_size = TypeInfo::size; @@ -715,6 +724,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, NVTE_ERROR("Invalid scaling type."); } } + NVTE_CHECK_CUDA(cudaGetLastError()); return; } @@ -794,7 +804,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); break; } case ScalingType::COLWISE: { @@ -809,7 +818,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); break; } case ScalingType::BIDIMENSIONAL: { @@ -824,10 +832,10 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); if constexpr (IS_DBIAS) { common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); From d4b2b8147ba8e12f5281b5f0e49f161a34f711b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 
May 2026 17:55:34 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 409d6a6e60..1549a292d8 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -650,9 +650,11 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, using bidimensional_traits = specialized::CastTraits; constexpr size_t max_grid_dim_y = 65535; const bool rowwise_specialized_grid_fits = - ((rows + rowwise_traits::blockDimM - 1) / rowwise_traits::blockDimM) <= max_grid_dim_y; + ((rows + rowwise_traits::blockDimM - 1) / rowwise_traits::blockDimM) <= + max_grid_dim_y; const bool bidimensional_specialized_grid_fits = - ((rows + bidimensional_traits::blockDIM::M - 1) / bidimensional_traits::blockDIM::M) <= max_grid_dim_y; + ((rows + bidimensional_traits::blockDIM::M - 1) / + bidimensional_traits::blockDIM::M) <= max_grid_dim_y; const bool is_full_rowwise_chunk = (cols % 128 == 0); const bool scaling_type_has_specialized_support = @@ -834,8 +836,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, scale_stride_colwise); break; } - } - NVTE_CHECK_CUDA(cudaGetLastError()); + } NVTE_CHECK_CUDA(cudaGetLastError()); if constexpr (IS_DBIAS) { common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
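
Taken together, patches 3 through 7 gate the fast cast-only path behind two runtime checks: the row length must be a whole number of 128-element chunks, and the number of row blocks must fit within CUDA's gridDim.y limit of 65535. The sketch below restates that host-side dispatch guard in isolation. It is illustrative only and assumes hypothetical stand-in names (chunk_elems, block_dim_m, use_specialized_kernel); the real values live in the specialized::CastTraits types and the quantize() dispatch shown in the hunks above.

// Minimal, self-contained sketch of the fallback conditions; not part of the patches.
#include <cstddef>
#include <cstdio>

enum class ScalingType { ROWWISE, COLWISE, BIDIMENSIONAL };

// Hypothetical constants standing in for the specialized CastTraits members.
constexpr size_t chunk_elems = 128;      // rowwise kernel consumes full 128-element chunks
constexpr size_t block_dim_m = 64;       // rows covered by one CUDA block (assumed value)
constexpr size_t max_grid_dim_y = 65535; // hardware limit on gridDim.y

// Returns true when the fast cast-only kernel may be used; otherwise the caller
// falls back to the generic quantize kernel.
bool use_specialized_kernel(ScalingType scaling, size_t rows, size_t cols,
                            bool with_gemm_swizzled_scales) {
  if (with_gemm_swizzled_scales) return false;
  const bool full_chunks = (cols % chunk_elems == 0);  // no partial row tail (e.g. N=48 fails)
  const bool grid_fits = ((rows + block_dim_m - 1) / block_dim_m) <= max_grid_dim_y;
  switch (scaling) {
    case ScalingType::ROWWISE:       return full_chunks && grid_fits;
    case ScalingType::BIDIMENSIONAL: return grid_fits;
    case ScalingType::COLWISE:       return false;  // always handled by the generic kernel
  }
  return false;
}

int main() {
  // cols=48 leaves a partial chunk, so the rowwise case falls back; cols=256 does not.
  std::printf("%d\n", use_specialized_kernel(ScalingType::ROWWISE, 1024, 48, false));   // 0
  std::printf("%d\n", use_specialized_kernel(ScalingType::ROWWISE, 1024, 256, false));  // 1
}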