[AIMIGRAPHX-408] Update intermediate ops to support dynamic shapes#4581
[AIMIGRAPHX-408] Update intermediate ops to support dynamic shapes#4581
Conversation
There was a problem hiding this comment.
Pull request overview
Updates GPU intermediate ops and MLIR fusion pipeline to better support fully dynamic-shape graphs (notably dot/gemm + pointwise), and adds/extends tests to validate behavior.
Changes:
- Make GPU contiguous/unary device eval paths handle dynamic output shapes correctly.
- Prevent
reduce_dimsfrom throwing on dynamic shapes and add coverage for dynamic-shape inputs. - Enable additional GPU compilation/fusion paths for dynamic graphs and add a dynamic GEMM+pointwise verify test.
Reviewed changes
Copilot reviewed 9 out of 10 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| test/verify/test_dynamic_pointwise.cpp | Minor comment formatting update in an existing dynamic pointwise verify test. |
| test/verify/test_dynamic_gemm_pointwise.cpp | Adds a new verify test covering dynamic batch/K GEMM followed by broadcast + pointwise ops (MLIR-gated). |
| test/verify/main.cpp | Registers the new dynamic GEMM+pointwise verify test instantiations and keeps dynamic pointwise tests listed. |
| test/reduce_dims.cpp | Adds a dynamic-shape test case for reduce_dims plus a small dynamic-shape helper. |
| src/reduce_dims.cpp | Early-return for dynamic shapes to avoid calling lens()/strides() on dynamic shapes. |
| src/targets/gpu/target.cpp | Runs fuse_pointwise_reduce unconditionally and enables fuse_mlir whenever MLIR is enabled (including full dynamic). |
| src/targets/gpu/include/migraphx/gpu/oper.hpp | Avoids reshaping through dynamic reduce_shapes and reshapes unary outputs using computed runtime shape when needed. |
| src/targets/gpu/include/migraphx/gpu/contiguous.hpp | Allows dynamic shapes in gpu::contiguous and returns a dynamic output shape when input is dynamic. |
| src/targets/gpu/fuse_mlir.cpp | Adjusts MLIR fusion matcher inheritance related to dynamic-shape matching. |
| src/targets/gpu/compile_ops.cpp | Makes runtime module naming robust when module_args are empty and aligns parameter naming with the selected runtime module name. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| struct find_mlir_fused_ops | ||
| { |
There was a problem hiding this comment.
Removing match::supports_dynamic_shapes from find_mlir_fused_ops means matcher.hpp will wrap the matcher with not_dynamic_shape(...), so this fusion will no longer trigger on dynamic-shaped graphs (dot/conv -> reshapes -> pointwise). If the intent of this PR is to support dynamic gemm+pointwise via MLIR fusion, consider re-adding match::supports_dynamic_shapes (or otherwise making the matcher explicitly dynamic-safe) so dynamic graphs still get fused; if this is intentional, it likely needs an in-code explanation because it counteracts enabling fuse_mlir under MIGRAPHX_ENABLE_FULL_DYNAMIC.
There was a problem hiding this comment.
What was the intent here?
There was a problem hiding this comment.
this doesnt actually work till we fix multibroadcasts for the dynamic case. Either I have to rewrite the logic to hadle that or this will work as it is once we rewrite dynamic broadcasts as single input ops with symbolic dimensions. Plan is to do the latter
There was a problem hiding this comment.
for now we are stuck with only using standalone ops
bdevorem
left a comment
There was a problem hiding this comment.
lgtm. The only thing I'd say is I'd appreciate more inline comments. Also, are you waiting until the work is done to update the changelog? That makes sense, but if you aren't, then just fyi that it hasn't been updated here
| struct find_mlir_fused_ops | ||
| { |
There was a problem hiding this comment.
What was the intent here?
Motivation
Enable running fully dynamic graphs with gemm and pointwise ops
Technical Details
Update contiguous op to work with dynamic shapes. Required for running mlir gemms
Changelog Category
Add a
CHANGELOG.mdentry for any option other thanNot ApplicableShould go in after #4549