Skip to content
Merged
40 changes: 27 additions & 13 deletions transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -643,15 +643,35 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,

// The specialized rowwise cast-only kernel vectorizes full 128-element chunks.
// Shapes with a partial row tail (for example, N=48) must use the generic kernel,
// otherwise the last chunk reads/writes past the logical end of the row.
using rowwise_traits = specialized::CastTraits<IType, OType, true, false>;
using bidimensional_traits = specialized::CastTraits<IType, OType, true, true>;
constexpr size_t max_grid_dim_y = 65535;
const bool rowwise_specialized_grid_fits =
((rows + rowwise_traits::blockDimM - 1) / rowwise_traits::blockDimM) <=
max_grid_dim_y;
const bool bidimensional_specialized_grid_fits =
((rows + bidimensional_traits::blockDIM::M - 1) /
bidimensional_traits::blockDIM::M) <= max_grid_dim_y;

const bool is_full_rowwise_chunk = (cols % 128 == 0);
const bool scaling_type_has_specialized_support =
(scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk &&
rowwise_specialized_grid_fits) ||
(scaling_type == ScalingType::BIDIMENSIONAL &&
bidimensional_specialized_grid_fits);

if (specialized::hasSpec<IS_DBIAS, IS_DACT, IS_ACT, IType, OType>() &&
!WITH_GEMM_SWIZZLED_SCALES) {
!WITH_GEMM_SWIZZLED_SCALES && scaling_type_has_specialized_support) {
switch (scaling_type) {
case ScalingType::ROWWISE: {
using traits = specialized::CastTraits<IType, OType, true, false>;
auto kernel = specialized::quantize_mxfp8_kernel_cast_only<traits>;

cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
traits::smem);
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, traits::smem));

dim3 block(traits::threadLayout::num, traits::warpLayout::N,
traits::warpLayout::M);
Expand All @@ -664,16 +684,12 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,

break;
}
case ScalingType::COLWISE: {
NVTE_WARN("Colwise scaling will fallback to original kernel.");
break;
}
case ScalingType::BIDIMENSIONAL: {
using traits = specialized::CastTraits<IType, OType, true, true>;
auto kernel = specialized::quantize_mxfp8_kernel_cast_only<traits>;

cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
traits::smem);
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, traits::smem));
// TMA for loading, so that we don't need STS for transposing
alignas(64) CUtensorMap tensor_map_input{};
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
Expand Down Expand Up @@ -710,6 +726,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
NVTE_ERROR("Invalid scaling type.");
}
}
NVTE_CHECK_CUDA(cudaGetLastError());
return;
}

Expand Down Expand Up @@ -789,7 +806,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
}
case ScalingType::COLWISE: {
Expand All @@ -804,7 +820,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
}
case ScalingType::BIDIMENSIONAL: {
Expand All @@ -819,10 +834,9 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
}
}
} NVTE_CHECK_CUDA(cudaGetLastError());

if constexpr (IS_DBIAS) {
common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,6 @@ __device__ __forceinline__ e8m0_t to_e8m0(IType amax) {
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} // anonymous namespace

inline bool is_cast_only_enabled() {
static bool enabled = []() {
const char *env = std::getenv("ENABLE_CAST_ONLY");
return env != nullptr && (env[0] == '1');
}();
return enabled;

// // FIXME: when finish debugging, remove this
// const char* env = std::getenv("ENABLE_CAST_ONLY");
// return env != nullptr && (env[0] == '1');
}

template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename IType, typename OType>
inline bool hasSpec() {
return false;
Expand All @@ -112,19 +100,19 @@ inline bool hasSpec() {
// OType could be [fp8e5m2, fp8e4m3]
template <>
inline bool hasSpec<false, false, false, fp16, fp8e5m2>() {
return is_cast_only_enabled();
return true;
}
template <>
inline bool hasSpec<false, false, false, fp16, fp8e4m3>() {
return is_cast_only_enabled();
return true;
}
template <>
inline bool hasSpec<false, false, false, bf16, fp8e5m2>() {
return is_cast_only_enabled();
return true;
}
template <>
inline bool hasSpec<false, false, false, bf16, fp8e4m3>() {
return is_cast_only_enabled();
return true;
}

template <int32_t _M, int32_t _N>
Expand Down
Loading