Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ add_executable(test_operator
test_cast_mxfp8_grouped.cu
test_cast_nvfp4_transpose.cu
test_cast_float8blockwise.cu
test_cast_float8blockwise_grouped.cu
test_dequantize_mxfp8.cu
test_dequantize_mxfp8_grouped.cu
test_dequantize_nvfp4.cu
Expand Down
449 changes: 449 additions & 0 deletions tests/cpp/operator/test_cast_float8blockwise_grouped.cu

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions transformer_engine/common/cast/dispatch/quantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "../../util/vectorized_pointwise.h"
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../fp8_blockwise/group_quantize_fp8_blockwise.cuh"
#include "../mxfp8/group_quantize_mxfp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
Expand Down Expand Up @@ -466,6 +467,20 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor
workspace_tensor, &quant_config_cpp, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for grouped NVTE_BLOCK_SCALING_1D.");
fp8_blockwise::group_quantize_blockwise_1d(input_tensor, output_tensor, noop_tensor,
quant_config_cpp.amax_epsilon,
quant_config_cpp.force_pow_2_scales, stream);
break;
}
case NVTE_BLOCK_SCALING_2D: {
NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for grouped NVTE_BLOCK_SCALING_2D.");
fp8_blockwise::group_quantize_blockwise_2d(input_tensor, output_tensor, noop_tensor,
quant_config_cpp.amax_epsilon,
quant_config_cpp.force_pow_2_scales, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
}
Expand Down Expand Up @@ -507,6 +522,22 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe
&quant_config_cpp, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
"IS_DBIAS and IS_DACT are not implemented for grouped NVTE_BLOCK_SCALING_1D.");
fp8_blockwise::group_quantize_blockwise_1d(grad_tensor, output_tensor, noop_tensor,
quant_config_cpp.amax_epsilon,
quant_config_cpp.force_pow_2_scales, stream);
break;
}
case NVTE_BLOCK_SCALING_2D: {
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
"IS_DBIAS and IS_DACT are not implemented for grouped NVTE_BLOCK_SCALING_2D.");
fp8_blockwise::group_quantize_blockwise_2d(grad_tensor, output_tensor, noop_tensor,
quant_config_cpp.amax_epsilon,
quant_config_cpp.force_pow_2_scales, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
}
Expand Down

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "swizzle.cuh"
#include "../swizzle.cuh"

namespace transformer_engine {
namespace dispatch {
Expand Down Expand Up @@ -70,7 +70,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;

using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;
using transformer_engine::dispatch::swizzle::gemm_swizzled_scale_idx;

constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
static_assert(STAGES >= 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "../core/common.cuh"
#include "swizzle.cuh"
#include "../swizzle.cuh"

namespace transformer_engine {
namespace dispatch {
Expand Down Expand Up @@ -120,7 +120,7 @@ __device__ __forceinline__ void process_colwise_stage(
const size_t tensor_scales_offset_colwise_base = tensor_base_row * cols_padded / SCALE_DIM_Y;
const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base;
scale_idx = tensor_scales_offset_colwise_base +
transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx(
transformer_engine::dispatch::swizzle::gemm_swizzled_scale_idx(
global_scales_offset_X, local_scales_offset_Y,
DIVUP(rows, static_cast<size_t>(scale_tensor_alignment_Y_rowwise)));
} else {
Expand Down Expand Up @@ -395,7 +395,7 @@ __device__ __forceinline__ void process_rowwise_stage(

size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx(
scale_idx = transformer_engine::dispatch::swizzle::gemm_swizzled_scale_idx(
stage_scales_offset_Y, stage_scales_offset_X,
DIVUP(cols, static_cast<size_t>(scale_tensor_alignment_X_colwise)));
} else {
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;

using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;
using transformer_engine::dispatch::swizzle::gemm_swizzled_scale_idx;

if constexpr (NO_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "../mxfp8/swizzle.cuh"
#include "../swizzle.cuh"

#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
Expand Down Expand Up @@ -55,7 +55,7 @@ __global__ void __launch_bounds__(512)
const size_t my_index = x + y * M;
size_t my_scale_index;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
my_scale_index = mxfp8::swizzle::gemm_swizzled_scale_idx(y, x, num_scale_tiles_X);
my_scale_index = swizzle::gemm_swizzled_scale_idx(y, x, num_scale_tiles_X);
} else {
my_scale_index = x + y * scale_stride;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,24 @@
************************************************************************/

/*! \file swizzle.cuh
* \brief Helper function for GEMM-swizzled scales
* \brief Helper function for GEMM-swizzled scales. Shared across MXFP8,
* NVFP4, and FP8 block-scaling paths.
*/

#ifndef TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_SWIZZLE_CUH_
#define TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_SWIZZLE_CUH_
#ifndef TRANSFORMER_ENGINE_COMMON_CAST_SWIZZLE_CUH_
#define TRANSFORMER_ENGINE_COMMON_CAST_SWIZZLE_CUH_

namespace transformer_engine {
namespace dispatch {
namespace mxfp8 {
namespace swizzle {

constexpr size_t GEMM_SWIZZLED_SCALE_TILE_DIM_X = 4;
constexpr size_t GEMM_SWIZZLED_SCALE_TILE_DIM_Y = 128;

/*! \brief Convert compact scale indices into GEMM swizzled scale index
*
* MXFP8 GEMM expects scaling factors to be in a "swizzled" order
* MXFP8 / NVFP4 / FP8 block-scaling GEMM expects scaling factors to be in a
* "swizzled" order
* (https://docs.nvidia.com/cuda/cublas/#d-block-scaling-factors-layout).
* This function converts indices from "compact" order (i.e. matching
* the FP8 data) to swizzled order.
Expand All @@ -41,8 +42,7 @@ __device__ __forceinline__ size_t gemm_swizzled_scale_idx(size_t i, size_t j, si
}

} // namespace swizzle
} // namespace mxfp8
} // namespace dispatch
} // namespace transformer_engine

#endif // TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_SWIZZLE_CUH_
#endif // TRANSFORMER_ENGINE_COMMON_CAST_SWIZZLE_CUH_
14 changes: 7 additions & 7 deletions transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <type_traits>
#include <vector>

#include "../cast/mxfp8/swizzle.cuh"
#include "../cast/swizzle.cuh"
#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/handle_manager.h"
Expand Down Expand Up @@ -1104,11 +1104,11 @@ __forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorSha

__forceinline__ __device__ int64_t padded_mxfp8_scale_inv_bytes(int64_t first, int64_t last,
bool rowwise) {
namespace mxfp8_swizzle = transformer_engine::dispatch::mxfp8::swizzle;
namespace gemm_swizzle = transformer_engine::dispatch::swizzle;
constexpr int64_t kMxfp8BlockSize = 32;
// x is the dimension along which quantization is applied, y is other dimension
const int64_t scale_tile_y = static_cast<int64_t>(mxfp8_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_Y);
const int64_t scale_tile_x = static_cast<int64_t>(mxfp8_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_X);
const int64_t scale_tile_y = static_cast<int64_t>(gemm_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_Y);
const int64_t scale_tile_x = static_cast<int64_t>(gemm_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_X);
// Padded byte size of the swizzled MXFP8 scale_inv for a single tensor with data
// shape (first, last). Rowwise scales use rows=first, cols=last; columnwise
// scales swap the orientation since they are stored in column-major order.
Expand All @@ -1125,10 +1125,10 @@ __forceinline__ __device__ int64_t padded_mxfp8_scale_inv_bytes(int64_t first, i

__forceinline__ __device__ int64_t padded_nvfp4_scale_inv_bytes(int64_t first, int64_t last,
bool rowwise) {
namespace mxfp8_swizzle = transformer_engine::dispatch::mxfp8::swizzle;
namespace gemm_swizzle = transformer_engine::dispatch::swizzle;
constexpr int64_t kNvfp4BlockSize = 16;
const int64_t scale_tile_y = static_cast<int64_t>(mxfp8_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_Y);
const int64_t scale_tile_x = static_cast<int64_t>(mxfp8_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_X);
const int64_t scale_tile_y = static_cast<int64_t>(gemm_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_Y);
const int64_t scale_tile_x = static_cast<int64_t>(gemm_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_X);
const int64_t scale_dim_y = rowwise ? first : last;
const int64_t data_dim_x = rowwise ? last : first;
const int64_t padded_scale_dim_y =
Expand Down
48 changes: 33 additions & 15 deletions transformer_engine/common/util/ptx.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,22 @@ constexpr bool is_supported_arch() {

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
__device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) : "memory");
#else
NVTE_DEVICE_ERROR("mbarrier_init is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
NVTE_DEVICE_ERROR("mbarrier_init is only supported on SM 9.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval
__device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory");
#else
NVTE_DEVICE_ERROR("mbarrier_invalid is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
NVTE_DEVICE_ERROR("mbarrier_invalid is only supported on SM 9.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
Expand All @@ -158,13 +158,13 @@ __device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) {

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
__device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), "r"(tx_count)
: "memory");
#else
NVTE_DEVICE_ERROR("mbarrier_arrive_expect_tx is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
NVTE_DEVICE_ERROR("mbarrier_arrive_expect_tx is only supported on SM 9.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

__device__ __forceinline__ void mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(
Expand Down Expand Up @@ -230,8 +230,26 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// global -> shared::cta (no cluster; valid on Hopper sm_90+ and Blackwell with
// cluster size 1). Used by the FP8 block-scaling grouped quantize kernels.
//
// Issues a single `cp.async.bulk.tensor.2d` tile copy whose destination is
// addressed in the `.shared::cta` state space and whose completion is reported
// to an mbarrier via `.mbarrier::complete_tx::bytes`. The function only issues
// the copy; it does not wait on the barrier itself.
//
// Parameters:
//   dst_shmem      - destination buffer; converted with __cvta_generic_to_shared,
//                    so it presumably must point into shared memory (confirm at
//                    call sites).
//   tensor_map_ptr - passed as the 64-bit tensor-map operand of the instruction.
//   offset_x       - first tile coordinate ({%2, ...} in the PTX operand list).
//   offset_y       - second tile coordinate ({..., %3} in the PTX operand list).
//   mbar           - mbarrier that receives the complete_tx byte-count update;
//                    also converted to a shared-state-space address.
//
// When compiled for an architecture below SM 9.0 the body reduces to
// NVTE_DEVICE_ERROR.
//
// NOTE(review): the `.shared::cta` destination form of cp.async.bulk.tensor may
// require a newer PTX ISA / CUDA toolkit than the `.shared::cluster` form —
// confirm the minimum supported toolkit before relying on this on older builds.
__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared_cta(
uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x,
const uint32_t offset_y, uint64_t *mbar) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// The PTX instruction takes 32-bit shared-state-space addresses ("r" operands),
// not generic pointers, for both the destination and the mbarrier.
uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem);
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
// Operand order: %0 = destination, %1 = tensor map, {%2, %3} = tile
// coordinates, %4 = mbarrier. The "memory" clobber keeps the compiler from
// reordering surrounding memory accesses across the copy issue.
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile"
".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(dst_shmem_ptr),
"l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr)
: "memory");
#else
// Unsupported architecture: report a device-side error instead of emitting PTX.
NVTE_DEVICE_ERROR("cp_async_bulk_tensor_2d_global_to_shared_cta is only supported on SM 9.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint32_t waitComplete;
asm volatile(
"{\n\t .reg .pred P_OUT; \n\t"
Expand All @@ -243,19 +261,19 @@ __device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, cons
: "memory");
return static_cast<bool>(waitComplete);
#else
NVTE_DEVICE_ERROR("mbarrier_try_wait_parity is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
NVTE_DEVICE_ERROR("mbarrier_try_wait_parity is only supported on SM 9.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
return true;
}

__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
}
#else
NVTE_DEVICE_ERROR("mbarrier_wait_parity is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
NVTE_DEVICE_ERROR("mbarrier_wait_parity is only supported on SM 9.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

__device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint64_t *mbar,
Expand Down
17 changes: 16 additions & 1 deletion transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
enum class GroupedQuantizationMode {
MXFP8_GROUPED_QUANTIZE,
NVFP4_GROUPED_QUANTIZE,
FP8_BLOCKWISE_GROUPED_QUANTIZE,
INVALID_FOR_GROUPED_QUANTIZE
};
GroupedQuantizationMode grouped_quantization_mode =
Expand All @@ -249,6 +250,8 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
grouped_quantization_mode = GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE;
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
grouped_quantization_mode = GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE;
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
grouped_quantization_mode = GroupedQuantizationMode::FP8_BLOCKWISE_GROUPED_QUANTIZE;
}

if (empty_input_buffer) {
Expand All @@ -274,9 +277,21 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
});
break;
}
case GroupedQuantizationMode::FP8_BLOCKWISE_GROUPED_QUANTIZE: {
Float8BlockQuantizer *fp8_block_quantizer_cpp =
static_cast<Float8BlockQuantizer *>(quantizer_cpp.get());
QuantizationConfigWrapper quant_config_cpp;
quant_config_cpp.set_force_pow_2_scales(fp8_block_quantizer_cpp->force_pow_2_scales);
quant_config_cpp.set_amax_epsilon(fp8_block_quantizer_cpp->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_group_quantize(grouped_input_tensor.data(), grouped_output_tensor_cpp.data(),
quant_config_cpp, at::cuda::getCurrentCUDAStream());
});
break;
}
case GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE:
default:
NVTE_ERROR("group_quantize: only support NVFP4 or MXFP8 quantizer.");
NVTE_ERROR("group_quantize: only support NVFP4, MXFP8, or Float8Blockwise quantizer.");
break;
}

Expand Down
Loading