[C] NVFP4 quantization for GroupedTensor
#2655
base: main
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Greptile Overview

Greptile Summary

This PR adds NVFP4 quantization support for GroupedTensor.

Key Changes

Implementation Details

The graph-safe variants differ from existing implementations by reading per-tensor metadata through device pointers rather than host-side arguments, which keeps the calls safe to capture in CUDA graphs.

The fusion kernel (group_row_col_rht_gemm_ntt_w_sfc_graph_safe) combines the Hadamard transform (RHT) with NVFP4 quantization (stochastic rounding) in a single pass, writing rowwise and columnwise quantized data along with scale factors and amax.

Notes
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User
participant API as nvte_group_hadamard_transform_*_graph_safe
participant Wrapper as GroupedTensorWrapper
participant Kernel as CUDA Kernel (SM100+)
participant Device as Device Memory
User->>Wrapper: Create GroupedTensorWrapper
Wrapper->>Wrapper: nvte_create_grouped_tensor()
User->>Wrapper: set_rowwise_data(), set_columnwise_data(), etc.
Wrapper->>Device: Set device pointers for tensor params
User->>API: nvte_group_hadamard_transform_amax_graph_safe()
API->>API: Convert NVTEGroupedTensor
API->>API: Validate num_tensors > 0
alt NVFP4 Quantization with Fusion
User->>API: nvte_group_hadamard_transform_cast_fusion_graph_safe()
API->>Kernel: group_row_col_rht_gemm_ntt_w_sfc_graph_safe()
Kernel->>Device: TMA load input tensors
Kernel->>Kernel: Hadamard transform (RHT)
Kernel->>Kernel: NVFP4 quantization (stochastic rounding)
Kernel->>Device: Store rowwise/columnwise quantized data
Kernel->>Device: Store scale factors and amax
else Simple Amax Computation
API->>Kernel: GraphSafeGroupHadamardAmaxTmaKernel()
Kernel->>Device: TMA load input tensors
Kernel->>Kernel: Compute amax (with/without RHT)
Kernel->>Device: Store rowwise/columnwise amax
end
Kernel-->>API: Kernel completion
API-->>User: Return
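
The diagram corresponds roughly to the host-side sketch below. It is illustrative only: the entry points and setters named here appear in this PR, but their exact parameter lists are assumptions (placeholder names like rowwise_ptr are invented), so read it as pseudocode-shaped C++ rather than a drop-in example.

// Hypothetical call flow mirroring the sequence diagram above.
GroupedTensorWrapper grouped(num_tensors);  // wraps nvte_create_grouped_tensor()

// Per-tensor buffers are registered as device pointers, which is what
// makes the subsequent launches safe to capture in a CUDA graph.
grouped.set_rowwise_data(rowwise_ptr, rowwise_dtype, rowwise_shape);
grouped.set_columnwise_data(colwise_ptr, colwise_dtype, colwise_shape);

// Amax-only path (with or without the Hadamard transform):
nvte_group_hadamard_transform_amax_graph_safe(grouped_in, amax_out, stream);

// Fused RHT + NVFP4 cast path (stochastic rounding, SM100+):
nvte_group_hadamard_transform_cast_fusion_graph_safe(grouped_in, grouped_out, stream);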
3 files reviewed, 3 comments
// TODO(zhongbo): double check the logic here
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                      (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                      packed_N, M, offsets);

// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
  sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
  sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                           make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
// TODO(zhongbo): double check the logic here
Multiple TODO comments request logic verification in critical group-index calculation and tensor-layout code. Verify that the group_idx calculation and the tensor layout construction are correct before merging.
Suggested change:

// Determine the current tensor group index based on tile offset
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                      (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                      packed_N, M, offsets);
// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
  sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
  sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                           make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
Tensor mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
                            reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
                        make_shape(M, cur_N), DStride{});  // (M, packed_N)
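
On the group_idx concern itself: get_current_tensor_id presumably maps a flat tile offset to the index of the member tensor that owns it via the offsets array. A hedged sketch of that kind of lookup follows; the real helper's signature and semantics (including how shape_rep and packed_N factor in) are not visible in this diff.

// Illustrative only: linear scan mapping a flat element offset to the
// owning tensor, assuming offsets[i] is the starting flat offset of
// tensor i. The actual get_current_tensor_id in this PR may differ.
__device__ __forceinline__ int find_group_index(const int64_t *offsets,
                                                int num_tensors,
                                                int64_t flat_offset) {
  int idx = 0;
  while (idx + 1 < num_tensors && flat_offset >= offsets[idx + 1]) {
    ++idx;
  }
  return idx;
}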
// TODO(zhongbo): double check the logic here
int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                          global_tile_n_offset * M, packed_N, M, offsets);

if (cur_group_idx != group_idx) {
  group_idx = cur_group_idx;
  c_global_amax_val = shared_storage.global_d_amax[group_idx];
  // update amax
  global_encode_scale = c_global_amax_val > 0.0f
                            ? cutlass::minimum_with_nan_propagation<float>{}(
                                  (fp8_max * fp4_max) / c_global_amax_val,
                                  cutlass::platform::numeric_limits<float>::max())
                            : 1.0f;
  global_decode_scale = 1.0f / global_encode_scale;
  if constexpr (kUseFastMath) {
    global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
  }
// TODO(zhongbo): double check the logic here
More TODO comments in the epilogue loop: verify the group-index recalculation and the amax scaling logic.
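
For context on the scaling being questioned: NVFP4 uses a two-level scheme in which a global FP32 scale maps the group's amax into the range representable by an FP8 block scale times an FP4 value. Assuming the standard limits fp8_max = 448 (E4M3) and fp4_max = 6 (E2M1), which this hunk does not spell out, a worked example:

#include <algorithm>
#include <cstdio>
#include <limits>

int main() {
  const float fp8_max = 448.0f, fp4_max = 6.0f;  // assumed NVFP4 limits
  float amax = 100.0f;                           // made-up global amax

  // Mirrors the epilogue logic above: clamp (fp8_max * fp4_max) / amax
  // to float max, falling back to 1.0f when amax is not positive.
  float encode = amax > 0.0f
                     ? std::min((fp8_max * fp4_max) / amax,
                                std::numeric_limits<float>::max())
                     : 1.0f;
  float decode = 1.0f / encode;
  std::printf("encode=%g decode=%g\n", encode, decode);  // encode=26.88
  return 0;
}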
using transformer_engine::detail::ShapeRepresentation;

void *input_base_ptr = reinterpret_cast<void *>(input->data.dptr);
// TODO(zhongbo): add input sanity checks here
Add input sanity checks as noted in the TODO.
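
A sketch of what those checks might look like, using the project's NVTE_CHECK assertion macro; the specific conditions are guesses from the surrounding code, not taken from the PR:

// Hypothetical sanity checks for the TODO above.
NVTE_CHECK(input != nullptr, "Grouped input must not be null.");
NVTE_CHECK(input->data.dptr != nullptr, "Input device pointer must not be null.");
NVTE_CHECK(num_tensors > 0, "Grouped tensor must contain at least one member tensor.");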
Fixes #2510
Description
Pieces taken from #2600.
Type of change
Changes
Checklist: