Paged Attention: rocMLIR backend changes#2221
Open
justinrosner wants to merge 4 commits into46-paged-attention-highlevelfrom
Open
Paged Attention: rocMLIR backend changes#2221justinrosner wants to merge 4 commits into46-paged-attention-highlevelfrom
justinrosner wants to merge 4 commits into46-paged-attention-highlevelfrom
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
This WIP pull request implements the backend/lowering pipeline for paged attention in rocMLIR, enabling efficient tiled loading from non-contiguous paged memory. The changes add support for passing page table information through the lowering stages, from high-level rock.attention operations down to low-level ROCDL.raw_ptr_buffer_load instructions that load directly from page pointers.
Changes:
- Added paged memory support to Rock dialect operations (BlockwiseLoadTileOp, ThreadwiseReadIntoOp, GlobalLoadOp) with new attributes for page tables and page sizes
- Implemented page pointer loading stage in BlockwiseLoadTileToThreadwise that loads page pointers to LDS before data loads
- Extended lowering pipeline to compute page indices from logical positions and emit buffer resource instructions
- Updated operandSegmentSizes across test files to accommodate new optional paging operands
- Added support for Slice and AddDim transforms in TransformToMemref for paged memory access patterns
- Implemented same-op interference tracking in ReuseLDS to prevent aliasing issues
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| mlir/test/Dialect/Rock/lowering_global_load_store.mlir | Added test cases for paged load operations (scalar, vector, OOB, large pages) |
| mlir/test/Dialect/Rock/gridwise_gemm_conservative_lds_barriers.mlir | Updated operandSegmentSizes for paged attention support |
| mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir | Updated operandSegmentSizes for GQA paged attention |
| mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_barriers.mlir | Updated operandSegmentSizes for barrier tests |
| mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir | Updated operandSegmentSizes and added comprehensive paged attention test |
| mlir/test/Dialect/Rock/gridwise-attention-prefix-causal.mlir | Updated operandSegmentSizes for prefix causal attention |
| mlir/lib/Dialect/Rock/utility/transformMapUtils.cpp | Implemented computeFlatPosition to evaluate transforms for page index calculation |
| mlir/lib/Dialect/Rock/Transforms/TransformToMemref.cpp | Added Slice-to-subview and AddDim support for expand_shape transforms |
| mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp | Added paged load path with LDS page pointer lookup and validity checks |
| mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp | Implemented buffer resource creation and ROCDL lowering for paged loads |
| mlir/lib/Dialect/Rock/Transforms/ReuseLDS.cpp | Added same-op interference tracking to prevent problematic LDS aliasing |
| mlir/lib/Dialect/Rock/Transforms/LowerRockReduce.cpp | Updated GlobalLoadOp signature for paging parameters |
| mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | Extracted paging info from deref ops and disabled DirectToLDS for paged loads |
| mlir/lib/Dialect/Rock/Transforms/ConvToGemm.cpp | Updated GlobalLoadOp signature for paging parameters |
| mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp | Implemented PagePtrLoad stage with page pointer loading to LDS |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Added verification logic for paged operation attributes |
| mlir/include/mlir/Dialect/Rock/utility/transformMapUtils.h | Added computeFlatPosition function declaration |
| mlir/include/mlir/Dialect/Rock/IR/RockOps.td | Extended op definitions with paged attention attributes |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
9959a7d to
fa551da
Compare
fa551da to
034180c
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
This PR implements the backend/lowering pipeline for paged attention in rocMLIR. The changes enable efficient tiled loading from non-contiguous paged memory by passing page table information through each lowering stage, culminating in ROCDL.raw_ptr_buffer_load instructions that load directly from page pointers.
Implements: https://amd-hub.atlassian.net/browse/AIROCMLIR-42
Technical Details
Overview of the flow:
rock.derefopsPagePtrLoadGlobalReadstageGlobalRead: CreatesThreadwiseReadIntoop with paging attributesTest Plan
Test Result
Submission Checklist