Add MXFP8 support with cuBLASMp #3145
MXFP8 comparison of cuBLASMp vs UB on DGX-B200:
Greptile Summary
This PR enables MXFP8 (block scaling) support for the cuBLASMp comm+GEMM overlap path by extending
Confidence Score: 5/5
Safe to merge; the MXFP8 canonicalization path follows the established FP8 pattern and is protected by both a compile-time version guard and runtime assertions on pointer validity.
The new MXFP8 branch in canonicalize_input correctly keeps the transpose flag unchanged (MXFP8 columnwise data is not a transposed view), mirrors the cublaslt_gemm.cu reference path, and guards the entire path with CUBLASMP_VERSION checks. Validation is thorough: mixed scaling modes are rejected, swizzled-scale format is asserted, and per-direction null checks are in place. Test infrastructure change is mechanical and correct.
No files require special attention.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[cublasmp_gemm called] --> B{scaling_mode A & B}
B -->|both tensor_scaling| C[FP8 / BF16 path]
B -->|both mxfp8_scaling| D[MXFP8 path]
B -->|mixed / unknown| E[NVTE_CHECK fails]
D --> F{CUBLASMP_VERSION < 801?}
F -->|yes| G[NVTE_ERROR - runtime throw]
F -->|no| H[Check with_gemm_swizzled_scales]
H --> I[canonicalize_input A]
I --> J{transa?}
J -->|yes| K[use row-wise data, keep trans flag]
J -->|no| L[use columnwise data, keep trans flag]
H --> M[canonicalize_input B]
M --> N{transb?}
N -->|yes| O[use columnwise data, keep trans flag]
N -->|no| P[use row-wise data, keep trans flag]
K --> Q[Set A scale mode VEC32_UE8M0]
L --> Q
O --> R[Set B scale mode VEC32_UE8M0]
P --> R
Q --> U[cublasMpMatmul]
R --> U
C --> T[Set scale mode SCALAR_FP32]
T --> U
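The decisions in the flowchart can be sketched as a small standalone C++ model. This is an illustration only, not the code from the PR: every type and function name below (`ScalingMode`, `DataLayout`, `select_scale_mode`, `select_mxfp8_operand`, `kMinMxfp8Version`) is hypothetical, and no real cuBLASMp or Transformer Engine API is called. It assumes the version threshold of 801 and the row-wise/columnwise choices exactly as drawn in the flowchart.

```cpp
#include <stdexcept>

// Hypothetical stand-ins for the scaling modes named in the flowchart.
enum class ScalingMode { kTensor, kMxfp8 };
enum class ScaleMode { kScalarFp32, kVec32Ue8m0 };
enum class DataLayout { kRowwise, kColumnwise };
enum class Operand { kA, kB };

// Assumed threshold, taken from the "CUBLASMP_VERSION < 801?" node.
constexpr int kMinMxfp8Version = 801;

struct OperandPlan {
  DataLayout layout;  // which copy of the data feeds the GEMM
  bool trans;         // transpose flag passed on to the matmul
};

// Picks the scale mode, rejecting the combinations the flowchart rejects.
inline ScaleMode select_scale_mode(ScalingMode mode_a, ScalingMode mode_b,
                                   bool swizzled_scales, int cublasmp_version) {
  if (mode_a != mode_b) {
    throw std::invalid_argument("A and B must use the same scaling mode");
  }
  if (mode_a == ScalingMode::kTensor) {
    return ScaleMode::kScalarFp32;  // FP8 / BF16 path
  }
  if (cublasmp_version < kMinMxfp8Version) {
    throw std::runtime_error("MXFP8 needs a newer cuBLASMp");
  }
  if (!swizzled_scales) {
    throw std::invalid_argument("MXFP8 expects GEMM-swizzled scales");
  }
  return ScaleMode::kVec32Ue8m0;  // one UE8M0 scale per 32-element block
}

// MXFP8 operand selection: the layout depends on the transpose flag,
// but the flag itself is passed through unchanged.
inline OperandPlan select_mxfp8_operand(Operand which, bool trans) {
  const bool rowwise = (which == Operand::kA) ? trans : !trans;
  return {rowwise ? DataLayout::kRowwise : DataLayout::kColumnwise, trans};
}
```

For example, `select_mxfp8_operand(Operand::kA, false)` yields columnwise data with the transpose flag still `false`, which matches the summary's point that MXFP8 columnwise data is not a transposed view.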
Reviews (2): Last reviewed commit: "Enable cuBLASMp MXFP8 overlap tests"