[asm] Add MXFP4 scaled MFMA support to Python and C++ backends and fix SGPR…#846

Merged
harsh-nod merged 1 commit into iree-org:main from harsh-nod:mxfp4_asm
Feb 10, 2026

@harsh-nod
Collaborator

… conflict

Add end-to-end support for MXFP4 scaled MFMA (v_mfma_scale_f32_16x16x128_f8f6f4) in both the Python and C++ wave_asm backends, and fix a critical SGPR register conflict that caused GPU memory access faults for kernels with 5+ arguments.

C++ backend (wave_asm):

  • AMDGPUHandlers.cpp: Add getScaledMFMAFormatCode() to map MLIR float types (Float4E2M1FN, Float6E2M3FN, etc.) to cbsz/blgp hardware format codes, and attach them as attributes on the generated V_MFMA_SCALE op.
  • AssemblyEmitter.cpp: Add .Case handlers for V_MFMA_SCALE_F32_16X16X128 and V_MFMA_SCALE_F32_32X32X64 that emit cbsz/blgp modifiers. Add handlers for DS_WRITE_B8, DS_WRITE_B16, DS_READ_U8, DS_READ_U16.
  • TranslateFromMLIR.cpp: Fix LDS stores to use DS_WRITE_B8 for 1-byte and DS_WRITE_B16 for 2-byte stores (previously fell through to DS_WRITE_B32, corrupting adjacent LDS memory for scale factors).
  • InstructionInfo.cpp: Add InstrDesc entries for ds_write_b8/b16, ds_read_u8/u16.
  • SCFHandlers.cpp: Fix critical SGPR conflict - loop counters were hardcoded to s32, which overlaps with SRDs when kernels have 5+ arguments (e.g., MXFP4 with a, a_scale, b, b_scale, c). Now dynamically computes loop counter base from getFirstFreeSgprAfterSRDs().
  • TranslateFromMLIR.h: Add getFirstFreeSgprAfterSRDs() that accounts for both regular and cache-swizzle SRDs to prevent register conflicts.
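
To make the two bug fixes above concrete, here is a minimal Python sketch (not the C++ code from this PR). `lds_store_op` mirrors the store-width selection described for TranslateFromMLIR.cpp, and `first_free_sgpr_after_srds` mirrors the idea behind getFirstFreeSgprAfterSRDs(). The function names, signatures, and the SRD base register (s16) are assumptions made for illustration; the base was picked only so that the overlap starts at five arguments, as the description states. The old hardcoded s32 comes from the description, and a buffer resource descriptor occupying 4 SGPRs is a hardware fact.

```python
SGPRS_PER_SRD = 4            # a buffer resource descriptor is 128 bits = 4 SGPRs
OLD_LOOP_COUNTER_SGPR = 32   # loop counters used to be hardcoded to s32


def lds_store_op(num_bytes: int) -> str:
    """Pick the LDS store whose width matches the value being stored.

    Using ds_write_b32 for a 1- or 2-byte value would also overwrite the
    neighbouring bytes, which is the corruption described above.
    """
    ops = {1: "ds_write_b8", 2: "ds_write_b16", 4: "ds_write_b32"}
    if num_bytes not in ops:
        # Wider stores exist on the hardware; this sketch covers only these.
        raise ValueError(f"store width not covered by this sketch: {num_bytes}")
    return ops[num_bytes]


def first_free_sgpr_after_srds(
    srd_base: int, num_regular_srds: int, num_swizzle_srds: int = 0
) -> int:
    """Index of the first SGPR past all SRDs (regular and cache-swizzle)."""
    return srd_base + SGPRS_PER_SRD * (num_regular_srds + num_swizzle_srds)


def overlaps_old_loop_counter(srd_base: int, num_srds: int) -> bool:
    """True if the SRD range would cover the old hardcoded s32."""
    end = first_free_sgpr_after_srds(srd_base, num_srds)
    return srd_base <= OLD_LOOP_COUNTER_SGPR < end


HYPOTHETICAL_SRD_BASE = 16  # assumption, not taken from the PR

# Four arguments: SRDs occupy s16..s31, so s32 is still free.
print(overlaps_old_loop_counter(HYPOTHETICAL_SRD_BASE, 4))
# Five arguments (a, a_scale, b, b_scale, c): SRDs occupy s16..s35.
print(overlaps_old_loop_counter(HYPOTHETICAL_SRD_BASE, 5))
# Allocating the loop counter from here avoids the overlap.
print(first_free_sgpr_after_srds(HYPOTHETICAL_SRD_BASE, 5))
```

Deriving the loop counter base from the SRD count, rather than from a constant, keeps the allocation correct however many arguments a kernel takes.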

Python backend:

  • kernel_mfma.py: Add cbsz:4 blgp:4 modifiers to scaled MFMA instructions.
  • kernel_ir.py: Add modifiers field to KInstr.
  • instruction_formatter.py: Support emitting instruction modifiers.
  • kernel_generator.py: Pass modifiers through to formatter.
  • handlers_memory.py: Add ds_read_b32 support for sub-4-byte LDS loads.
  • kernel_compilation_context.py: Add emit_lds_read_b32 helper.
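
As a rough illustration of how a modifiers field can flow from instruction construction to the emitted text, the sketch below uses a stand-in dataclass (`SketchInstr`) rather than the real KInstr, whose fields are not shown in this description. The format-code table reflects my understanding of the f8f6f4 encoding, with cbsz describing the A operand and blgp the B operand; only the FP4 value of 4 is confirmed by the cbsz:4 blgp:4 entry above, so treat the other codes as assumptions. The operand strings are placeholders, not real register assignments.

```python
from dataclasses import dataclass, field
from typing import List

# Assumed mapping from MLIR float type names to cbsz/blgp format codes.
# Only the FP4 entry (4) is confirmed by the PR description.
SCALED_MFMA_FORMAT_CODES = {
    "Float8E4M3FN": 0,
    "Float8E5M2": 1,
    "Float6E2M3FN": 2,
    "Float6E3M2FN": 3,
    "Float4E2M1FN": 4,
}


@dataclass
class SketchInstr:
    """Hypothetical stand-in for an instruction record with modifiers."""
    mnemonic: str
    operands: List[str]
    modifiers: List[str] = field(default_factory=list)


def scaled_mfma_modifiers(a_type: str, b_type: str) -> List[str]:
    """Build the cbsz/blgp modifier strings for the A and B element types."""
    return [
        f"cbsz:{SCALED_MFMA_FORMAT_CODES[a_type]}",
        f"blgp:{SCALED_MFMA_FORMAT_CODES[b_type]}",
    ]


def format_instr(instr: SketchInstr) -> str:
    """Emit 'mnemonic op0, op1, ... mod0 mod1', omitting empty modifiers."""
    text = f"{instr.mnemonic} {', '.join(instr.operands)}"
    if instr.modifiers:
        text += " " + " ".join(instr.modifiers)
    return text


mfma = SketchInstr(
    mnemonic="v_mfma_scale_f32_16x16x128_f8f6f4",
    operands=["<dst>", "<a>", "<b>", "<acc>", "<a_scale>", "<b_scale>"],
    modifiers=scaled_mfma_modifiers("Float4E2M1FN", "Float4E2M1FN"),
)
print(format_instr(mfma))
```

Keeping modifiers as a separate list, instead of appending them to the last operand, lets the formatter leave instructions without modifiers unchanged.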

Tests:

  • test_asm_backend_e2e.py: Fix device_randn -> device_randint for int8 tensors.
  • asm_backend_test.py: Same device_randn fix for Python backend tests.

Contributor

@panditsa panditsa left a comment

LGTM!

]

@tkw.wave(constraints)
def mxfp4_gemm_kernel(
Contributor

Is there a reason why we can't modify the existing MXFP test with asm backend?

Collaborator Author

Originally, to keep both backends separate (as requested by other teams), the tests were copied from what already existed for the Python backend. I am keeping that separation for now, but we could unify them in the future.

Commit message: same text as the PR description above, with one addition under Tests for test_asm_backend_e2e.py: make ADDRESS_SPACE conditional on use_global_to_shared.

Signed-off-by: Harsh Menon <harsh.menon@amd.com>
@harsh-nod harsh-nod merged commit c50d3f9 into iree-org:main Feb 10, 2026
14 of 15 checks passed
nirmie pushed a commit to nirmie/wave that referenced this pull request Mar 9, 2026