[asm] Add MXFP4 scaled MFMA support to Python and C++ backends and fix SGPR…#846
Merged
harsh-nod merged 1 commit intoiree-org:mainfrom Feb 10, 2026
Merged
[asm] Add MXFP4 scaled MFMA support to Python and C++ backends and fix SGPR…#846harsh-nod merged 1 commit intoiree-org:mainfrom
harsh-nod merged 1 commit intoiree-org:mainfrom
Conversation
nithinsubbiah
approved these changes
Feb 9, 2026
| ] | ||
|
|
||
| @tkw.wave(constraints) | ||
| def mxfp4_gemm_kernel( |
Contributor
There was a problem hiding this comment.
Is there a reason why we can't modify the existing MXFP test with asm backend?
Collaborator
Author
There was a problem hiding this comment.
Originally, to keep both backends separate (as requested by other teams) the tests were a copy of what already existed for the Python backend. I am just keeping that for separation for now, but we could unify them in the future.
…x SGPR conflict Add end-to-end support for MXFP4 scaled MFMA (v_mfma_scale_f32_16x16x128_f8f6f4) in both the Python and C++ wave_asm backends, and fix a critical SGPR register conflict that caused GPU memory access faults for kernels with 4+ arguments. C++ backend (wave_asm): - AMDGPUHandlers.cpp: Add getScaledMFMAFormatCode() to map MLIR float types (Float4E2M1FN, Float6E2M3FN, etc.) to cbsz/blgp hardware format codes, and attach them as attributes on the generated V_MFMA_SCALE op. - AssemblyEmitter.cpp: Add .Case handlers for V_MFMA_SCALE_F32_16X16X128 and V_MFMA_SCALE_F32_32X32X64 that emit cbsz/blgp modifiers. Add handlers for DS_WRITE_B8, DS_WRITE_B16, DS_READ_U8, DS_READ_U16. - TranslateFromMLIR.cpp: Fix LDS stores to use DS_WRITE_B8 for 1-byte and DS_WRITE_B16 for 2-byte stores (previously fell through to DS_WRITE_B32, corrupting adjacent LDS memory for scale factors). - InstructionInfo.cpp: Add InstrDesc entries for ds_write_b8/b16, ds_read_u8/u16. - SCFHandlers.cpp: Fix critical SGPR conflict - loop counters were hardcoded to s32, which overlaps with SRDs when kernels have 5+ arguments (e.g., MXFP4 with a, a_scale, b, b_scale, c). Now dynamically computes loop counter base from getFirstFreeSgprAfterSRDs(). - TranslateFromMLIR.h: Add getFirstFreeSgprAfterSRDs() that accounts for both regular and cache-swizzle SRDs to prevent register conflicts. Python backend: - kernel_mfma.py: Add cbsz:4 blgp:4 modifiers to scaled MFMA instructions. - kernel_ir.py: Add modifiers field to KInstr. - instruction_formatter.py: Support emitting instruction modifiers. - kernel_generator.py: Pass modifiers through to formatter. - handlers_memory.py: Add ds_read_b32 support for sub-4-byte LDS loads. - kernel_compilation_context.py: Add emit_lds_read_b32 helper. Tests: - test_asm_backend_e2e.py: Fix device_randn -> device_randint for int8 tensors, make ADDRESS_SPACE conditional on use_global_to_shared. - asm_backend_test.py: Same device_randn fix for Python backend tests. Signed-off-by: Harsh Menon <harsh.menon@amd.com>
nirmie
pushed a commit
to nirmie/wave
that referenced
this pull request
Mar 9, 2026
…x SGPR… (iree-org#846) … conflict Add end-to-end support for MXFP4 scaled MFMA (v_mfma_scale_f32_16x16x128_f8f6f4) in both the Python and C++ wave_asm backends, and fix a critical SGPR register conflict that caused GPU memory access faults for kernels with 4+ arguments. C++ backend (wave_asm): - AMDGPUHandlers.cpp: Add getScaledMFMAFormatCode() to map MLIR float types (Float4E2M1FN, Float6E2M3FN, etc.) to cbsz/blgp hardware format codes, and attach them as attributes on the generated V_MFMA_SCALE op. - AssemblyEmitter.cpp: Add .Case handlers for V_MFMA_SCALE_F32_16X16X128 and V_MFMA_SCALE_F32_32X32X64 that emit cbsz/blgp modifiers. Add handlers for DS_WRITE_B8, DS_WRITE_B16, DS_READ_U8, DS_READ_U16. - TranslateFromMLIR.cpp: Fix LDS stores to use DS_WRITE_B8 for 1-byte and DS_WRITE_B16 for 2-byte stores (previously fell through to DS_WRITE_B32, corrupting adjacent LDS memory for scale factors). - InstructionInfo.cpp: Add InstrDesc entries for ds_write_b8/b16, ds_read_u8/u16. - SCFHandlers.cpp: Fix critical SGPR conflict - loop counters were hardcoded to s32, which overlaps with SRDs when kernels have 5+ arguments (e.g., MXFP4 with a, a_scale, b, b_scale, c). Now dynamically computes loop counter base from getFirstFreeSgprAfterSRDs(). - TranslateFromMLIR.h: Add getFirstFreeSgprAfterSRDs() that accounts for both regular and cache-swizzle SRDs to prevent register conflicts. Python backend: - kernel_mfma.py: Add cbsz:4 blgp:4 modifiers to scaled MFMA instructions. - kernel_ir.py: Add modifiers field to KInstr. - instruction_formatter.py: Support emitting instruction modifiers. - kernel_generator.py: Pass modifiers through to formatter. - handlers_memory.py: Add ds_read_b32 support for sub-4-byte LDS loads. - kernel_compilation_context.py: Add emit_lds_read_b32 helper. Tests: - test_asm_backend_e2e.py: Fix device_randn -> device_randint for int8 tensors. - asm_backend_test.py: Same device_randn fix for Python backend tests. Signed-off-by: Harsh Menon <harsh.menon@amd.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
… conflict
Add end-to-end support for MXFP4 scaled MFMA (v_mfma_scale_f32_16x16x128_f8f6f4) in both the Python and C++ wave_asm backends, and fix a critical SGPR register conflict that caused GPU memory access faults for kernels with 4+ arguments.
C++ backend (wave_asm):
Python backend:
Tests: