Add squeeze, gather, and scatter_none ops to MIGraphX dialect#2176
Add squeeze, gather, and scatter_none ops to MIGraphX dialect#2176justinrosner wants to merge 12 commits intodevelopfrom
Conversation
|
Note for reviewers, this PR is operating under the assumption that we will have the upstream changes in soon so that we can have this change: llvm/llvm-project#167894 (adding i64 type support for indices in scatter/gather). To be conservative, I can keep the conversion from i64 -> i32 and just add a TODO and file a ticket to make that simple change once the upstream PR makes its way in. Let me know what you guys want. |
There was a problem hiding this comment.
Pull request overview
This PR adds three new operations to the MIGraphX dialect as part of the paged attention work: squeeze, gather, and scatter_none. These ops are based on ONNX specifications and include complete implementations with verifiers, TOSA lowering patterns, and comprehensive test coverage.
- Implements ONNX-compatible squeeze (removing size-1 dimensions), gather (collecting slices along an axis), and scatter_none (updating elements at specified indices)
- Adds verifiers for axis bounds checking, shape validation, and index constraints
- Provides TOSA lowering implementations that reshape tensors to match TOSA's 3D gather/scatter semantics
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td | Adds TableGen definitions for squeeze, gather, and scatter_none ops with documentation and assembly formats |
| mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp | Implements verify() methods for the three new ops including axis bounds checking and shape validation |
| mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp | Adds converter patterns to lower each new op to TOSA operations via reshape/gather/scatter transformations |
| mlir/test/Dialect/MIGraphX/invalid.mlir | Adds negative test cases for axis out of bounds, rank mismatches, and invalid indices |
| mlir/test/Conversion/MIGraphXToTosa/migraphx-to-tosa.mlir | Adds positive conversion tests for squeeze, gather, and scatter_none with various axis configurations |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| }]; | ||
| } | ||
|
|
||
| def MIGraphX_SqueezeOp |
There was a problem hiding this comment.
You dont need to add squeeze and unsqueeze as we already lower this to reshape when going to mlir.
There was a problem hiding this comment.
The IR that I was basing this off of (which came from MIGraphX) was using the unlowered squeeze/unsqueeze. I think it's a fairly trivial lowering, so it wouldn't hurt if rocMLIR could also do something like this.
|
Outside of paged attention, will this enable us to fuse gather and scatter into gemms and convolutions? |
This PR in of itself won't get us to fusing gather/scatter operations. We would still need an additional backend lowering PR for this. |
| let hasVerifier = 1; | ||
| } | ||
|
|
||
| def MIGraphX_ScatterNoneOp |
There was a problem hiding this comment.
why is it called ScatterNone?
There was a problem hiding this comment.
MIGraphX was passing us a scatter_none op which directly correlates to the ScatterElements ONNX op with reduction set to none: https://onnx.ai/onnx/operators/onnx__ScatterElements.html
| } | ||
| } | ||
|
|
||
| // Build the new shape by excluding the squeezed axes |
There was a problem hiding this comment.
should we check that axesToSqueeze are all dimension=1?
There was a problem hiding this comment.
The verifier already does this validation.
There was a problem hiding this comment.
Or is that not enough? Since if I recall correctly the verifier isn't turned on by default?
| auto outputType = cast<RankedTensorType>(outputTy); | ||
| Type elemType = dataType.getElementType(); | ||
|
|
||
| // Lowering strategy for migraphx.gather -> tosa.gather: |
There was a problem hiding this comment.
why not skip tosa instead of this workaround?
There was a problem hiding this comment.
I was thinking about needing a lowering path through TOSA for the CPU path in the future (there currently isn't a lowering path that goes past Tosa). I've already opened up a ticket to address this: https://github.com/ROCm/rocMLIR-internal/issues/2205. It looks like IREE has already encountered this before and they have a custom lowering that maybe we could port/use.
| // Output: [N, W, C] where each [n, w, :] = reshaped_data[n, indices[n, w], :] | ||
| SmallVector<int64_t> gatherOutputShape = {N, W, C}; | ||
| auto gatherOutputType = RankedTensorType::get(gatherOutputShape, elemType); | ||
| Value gatherResult = tosa::GatherOp::create(rewriter, loc, gatherOutputType, |
There was a problem hiding this comment.
I understand we are missing tosa.gather -> rock? are we going to support this for paged attention for now?
There was a problem hiding this comment.
The tosa.gather -> rock was going to come in a future PR that handled the TosaToRock changes. However, in recent conversations regarding paged attention, it seems like we are moving away from the scatter/gather implementation.
There was a problem hiding this comment.
do we need this PR then?
There was a problem hiding this comment.
@pfultz2 Are there any plans for MIGraphX to require rocMLIR supporting gather/scatter outside of paged attention (think I might have heard something in one of the meetings yesterday)? If not, then I think we can close this.
There was a problem hiding this comment.
Fusing gather for gemms and convolutions will be useful as well, especially for resize.
| if (axis < 0) | ||
| axis += dataRank; | ||
|
|
||
| // TOSA scatter requires that indices be constant across the "C" dimension |
There was a problem hiding this comment.
same here, why not skip tosa?
There was a problem hiding this comment.
See comment above about CPU lowering.
37703a4 to
a4462ca
Compare
|
Closing out this PR for now to reflect the change in scope (supporting |
Motivation
This PR adds three new ops to the MIGraphX dialect (squeeze, gather, and scatter_none). These new ops are required as part of the paged attention work.
This implements https://github.com/ROCm/rocMLIR-internal/issues/2200
Technical Details
This PR can be broken down into the following changes:
MIGraphX.td:migraphx.squeeze(https://onnx.ai/onnx/operators/onnx__Squeeze.html): Remove dimensions of size 1migraphx.gather(https://onnx.ai/onnx/operators/onnx__Gather.html): Gather slices from data along an axismigraphx.scatter_none(https://onnx.ai/onnx/operators/onnx__ScatterElements.html): Updates into data at specified indicesMIGraphX.cpp: Add verifiers for each of the new opsMIGraphXToTosa: Lowers each of the new ops to their set of TOSA opsTest Plan
Test Result
Submission Checklist