Skip to content

Paged Attention: rocMLIR backend changes#2221

Open
justinrosner wants to merge 4 commits into46-paged-attention-highlevelfrom
42-paged-attention-rocmlir
Open

Paged Attention: rocMLIR backend changes#2221
justinrosner wants to merge 4 commits into46-paged-attention-highlevelfrom
42-paged-attention-rocmlir

Conversation

@justinrosner
Copy link
Copy Markdown
Contributor

@justinrosner justinrosner commented Jan 30, 2026

Motivation

This PR implements the backend/lowering pipeline for paged attention in rocMLIR. The changes enable efficient tiled loading from non-contiguous paged memory by passing page table information through each lowering stage, culminating in ROCDL.raw_ptr_buffer_load instructions that load directly from page pointers.

Implements: https://amd-hub.atlassian.net/browse/AIROCMLIR-42

Technical Details

Overview of the flow:

rock.attention (with keyAddresses/valueAddresses)
       │
       ▼
rock.gridwise_attention_accel (extracts page table from rock.deref)
       │
       ▼
rock.blockwise_load_tile (with pageTable + pageSize)
       │
       ├──► Stage: PagePtrLoad (load page pointers to LDS)
       │
       └──► Stage: GlobalRead (with ldsPagePtrs + firstPageIndex)
                   │
                   ▼
            rock.threadwise_read_into (with paging attributes)
                   │
                   ▼
            rock.global_load (with pagePtr + pageSize)
                   │
                   ▼
            ROCDL.raw_ptr_buffer_load (final HW instruction)
  1. GridwiseGemmToBlockwise
  1. BlockwiseLoadTileToThreadwise
  • Splits the load into two stages for paged loads:
    • PagePtrLoad
      • Compute the firstPageIdx by evaluating transforms to get the tile's starting flat position
      • Allocates LDS buffer for page pointers
      • Each thread loads one page pointer from the global page table to LDS
      • Issues LDS barrier to synchronize before GlobalRead stage
    • GlobalRead: Creates ThreadwiseReadInto op with paging attributes
  1. ThreadwiseGemmLowering
  • Transform maps produce coordinates in [batch, pageIdx, offsetInPage] form
  • Computes LDS page index: globalPageIdx - firstPageIdx
  • Loads page pointer from LDS with bounds clamping
  • Adds validity check for null pointers (pages beyond table bounds)
  • Emits GlobalLoadOp with pagePtr and pageSize attributes
  1. SugarToLoops
  • Creates buffer resource (V#) from raw page pointer using ROCDL.make.buffer.rsrc
  • Converts i64 pointer to LLVM pointer type
  • Sets appropriate buffer descriptor flags (RDNA vs CDNA)
  • Uses page size as numRecords for bounds checking
  • Emits ROCDL.raw_ptr_buffer_load with:
  • Buffer resource from page pointer
  • Offset within page (in bytes)
  • Validity-guarded conditional load (scf.if)

Test Plan

  • PR CI

Test Result

  • PR CI

Submission Checklist

@justinrosner justinrosner marked this pull request as ready for review January 30, 2026 18:31
@justinrosner justinrosner requested a review from causten as a code owner January 30, 2026 18:31
@justinrosner justinrosner changed the title [WIP] Paged Attention: rocMLIR backend changes [WIP- NOT READY FOR REVIEW] Paged Attention: rocMLIR backend changes Jan 30, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This WIP pull request implements the backend/lowering pipeline for paged attention in rocMLIR, enabling efficient tiled loading from non-contiguous paged memory. The changes add support for passing page table information through the lowering stages, from high-level rock.attention operations down to low-level ROCDL.raw_ptr_buffer_load instructions that load directly from page pointers.

Changes:

  • Added paged memory support to Rock dialect operations (BlockwiseLoadTileOp, ThreadwiseReadIntoOp, GlobalLoadOp) with new attributes for page tables and page sizes
  • Implemented page pointer loading stage in BlockwiseLoadTileToThreadwise that loads page pointers to LDS before data loads
  • Extended lowering pipeline to compute page indices from logical positions and emit buffer resource instructions
  • Updated operandSegmentSizes across test files to accommodate new optional paging operands
  • Added support for Slice and AddDim transforms in TransformToMemref for paged memory access patterns
  • Implemented same-op interference tracking in ReuseLDS to prevent aliasing issues

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
mlir/test/Dialect/Rock/lowering_global_load_store.mlir Added test cases for paged load operations (scalar, vector, OOB, large pages)
mlir/test/Dialect/Rock/gridwise_gemm_conservative_lds_barriers.mlir Updated operandSegmentSizes for paged attention support
mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir Updated operandSegmentSizes for GQA paged attention
mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_barriers.mlir Updated operandSegmentSizes for barrier tests
mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir Updated operandSegmentSizes and added comprehensive paged attention test
mlir/test/Dialect/Rock/gridwise-attention-prefix-causal.mlir Updated operandSegmentSizes for prefix causal attention
mlir/lib/Dialect/Rock/utility/transformMapUtils.cpp Implemented computeFlatPosition to evaluate transforms for page index calculation
mlir/lib/Dialect/Rock/Transforms/TransformToMemref.cpp Added Slice-to-subview and AddDim support for expand_shape transforms
mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp Added paged load path with LDS page pointer lookup and validity checks
mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp Implemented buffer resource creation and ROCDL lowering for paged loads
mlir/lib/Dialect/Rock/Transforms/ReuseLDS.cpp Added same-op interference tracking to prevent problematic LDS aliasing
mlir/lib/Dialect/Rock/Transforms/LowerRockReduce.cpp Updated GlobalLoadOp signature for paging parameters
mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Extracted paging info from deref ops and disabled DirectToLDS for paged loads
mlir/lib/Dialect/Rock/Transforms/ConvToGemm.cpp Updated GlobalLoadOp signature for paging parameters
mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp Implemented PagePtrLoad stage with page pointer loading to LDS
mlir/lib/Dialect/Rock/IR/RockDialect.cpp Added verification logic for paged operation attributes
mlir/include/mlir/Dialect/Rock/utility/transformMapUtils.h Added computeFlatPosition function declaration
mlir/include/mlir/Dialect/Rock/IR/RockOps.td Extended op definitions with paged attention attributes

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/TransformToMemref.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp
@justinrosner justinrosner force-pushed the 42-paged-attention-rocmlir branch from 9959a7d to fa551da Compare January 30, 2026 21:56
@justinrosner justinrosner force-pushed the 42-paged-attention-rocmlir branch from fa551da to 034180c Compare February 2, 2026 22:14
@justinrosner justinrosner changed the title [WIP- NOT READY FOR REVIEW] Paged Attention: rocMLIR backend changes Paged Attention: rocMLIR backend changes Feb 4, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants