diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a0ae7dde82..1549a292d8 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -643,15 +643,35 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION( with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + // The specialized rowwise cast-only kernel vectorizes full 128-element chunks. + // Shapes with a partial row tail (for example, N=48) must use the generic kernel, + // otherwise the last chunk reads/writes past the logical end of the row. + using rowwise_traits = specialized::CastTraits; + using bidimensional_traits = specialized::CastTraits; + constexpr size_t max_grid_dim_y = 65535; + const bool rowwise_specialized_grid_fits = + ((rows + rowwise_traits::blockDimM - 1) / rowwise_traits::blockDimM) <= + max_grid_dim_y; + const bool bidimensional_specialized_grid_fits = + ((rows + bidimensional_traits::blockDIM::M - 1) / + bidimensional_traits::blockDIM::M) <= max_grid_dim_y; + + const bool is_full_rowwise_chunk = (cols % 128 == 0); + const bool scaling_type_has_specialized_support = + (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk && + rowwise_specialized_grid_fits) || + (scaling_type == ScalingType::BIDIMENSIONAL && + bidimensional_specialized_grid_fits); + if (specialized::hasSpec() && - !WITH_GEMM_SWIZZLED_SCALES) { + !WITH_GEMM_SWIZZLED_SCALES && scaling_type_has_specialized_support) { switch (scaling_type) { case ScalingType::ROWWISE: { using traits = specialized::CastTraits; auto kernel = specialized::quantize_mxfp8_kernel_cast_only; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - traits::smem); + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, traits::smem)); dim3 block(traits::threadLayout::num, traits::warpLayout::N, traits::warpLayout::M); @@ -664,16 +684,12 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } - case ScalingType::COLWISE: { - NVTE_WARN("Colwise scaling will fallback to original kernel."); - break; - } case ScalingType::BIDIMENSIONAL: { using traits = specialized::CastTraits; auto kernel = specialized::quantize_mxfp8_kernel_cast_only; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - traits::smem); + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, traits::smem)); // TMA for loading, so that we don't need STS for transposing alignas(64) CUtensorMap tensor_map_input{}; constexpr size_t input_type_bit_size = TypeInfo::size; @@ -710,6 +726,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, NVTE_ERROR("Invalid scaling type."); } } + NVTE_CHECK_CUDA(cudaGetLastError()); return; } @@ -789,7 +806,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); break; } case ScalingType::COLWISE: { @@ -804,7 +820,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); break; } case ScalingType::BIDIMENSIONAL: { @@ -819,10 +834,9 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); break; } - } + } NVTE_CHECK_CUDA(cudaGetLastError()); if constexpr (IS_DBIAS) { common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); diff --git a/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh index 41e62ac319..9459f0273a 100644 --- a/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh @@ -91,18 +91,6 @@ __device__ __forceinline__ e8m0_t to_e8m0(IType amax) { #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // anonymous namespace -inline bool is_cast_only_enabled() { - static bool enabled = []() { - const char *env = std::getenv("ENABLE_CAST_ONLY"); - return env != nullptr && (env[0] == '1'); - }(); - return enabled; - - // // FIXME: when finish debugging, remove this - // const char* env = std::getenv("ENABLE_CAST_ONLY"); - // return env != nullptr && (env[0] == '1'); -} - template inline bool hasSpec() { return false; @@ -112,19 +100,19 @@ inline bool hasSpec() { // OType could be [fp8e5m2, fp8e4m3] template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template