simplify_reshapes: skip find_reshape_dot when it would change element…#4994
simplify_reshapes: skip find_reshape_dot when it would change element…#4994ycastill2-amd wants to merge 2 commits into
Conversation
… count find_reshape_dot moves a reshape across a dot and then reshapes the new dot's output back to the original dot's shape. When the rewrite changes the total element count (e.g. the second input's reshape preserves its contraction axis but changes N), that trailing reshape is invalid and compilation aborts with "Reshape: Wrong number of elements" (observed on an RT-DETR detection model on gfx1150). Skip the rewrite, and drop the speculatively-created dot, when the new dot's element count differs from the original dot's. Signed-off-by: Yviel Castillejos <ycastill@amd.com>
|
Thank you for your contribution! Since this is an external pull request, a maintainer must review PR and add the "ok-to-test" label if it is approved for testing. |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## develop #4994 +/- ##
===========================================
- Coverage 92.73% 92.71% -0.03%
===========================================
Files 594 596 +2
Lines 31340 31494 +154
===========================================
+ Hits 29063 29197 +134
- Misses 2277 2297 +20
🚀 New features to boost your workflow:
|
|
Can you add a unit test for this? |
Add reshape_dot_changed_element_count, which builds a reshape -> dot -> reshape pattern where the second input's reshape preserves its contraction axis but changes N. Moving the reshapes across the dot would change the total element count and make the trailing reshape invalid, so find_reshape_dot must skip the rewrite and leave the module unchanged. Signed-off-by: Yviel Castillejos <ycastill@amd.com>
| // Moving the reshape across the dot is only valid when the element count is | ||
| // preserved; otherwise the trailing reshape back to the original dot shape is | ||
| // invalid (observed creating an N->M element reshape that aborts compilation). | ||
| if(new_dot->get_shape().elements() != dot->get_shape().elements()) |
There was a problem hiding this comment.
I think we should do this before we add the new instruction to avoid doing a remove_instruction below. Just have an early return. The number of elements can be inferred from the inp and new_other for the check.
There was a problem hiding this comment.
Pull request overview
This PR hardens the simplify_reshapes pass by preventing find_reshape_dot from performing a reshape-across-dot rewrite when it would change the dot output’s total element count, which would make the trailing reshape invalid and trigger a reshape element-count error during compilation.
Changes:
- Add an element-count guard in
find_reshape_dotand bail out when the rewrite would change the dot output size. - Add a regression test ensuring the pass leaves the module unchanged for a reshape-dot case where the rewrite would change element count.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
src/simplify_reshapes.cpp |
Adds element-count validation to skip the reshape-dot rewrite when it would make the trailing reshape invalid. |
test/simplify_reshapes_test.cpp |
Adds a regression test covering the element-count-changing reshape-dot scenario. |
| // Moving the reshape across the dot is only valid when the element count is | ||
| // preserved; otherwise the trailing reshape back to the original dot shape is | ||
| // invalid (observed creating an N->M element reshape that aborts compilation). | ||
| if(new_dot->get_shape().elements() != dot->get_shape().elements()) | ||
| { | ||
| m.remove_instruction(new_dot); | ||
| return; | ||
| } |
| // find_reshape_dot would move the reshapes across the dot and then reshape the new | ||
| // dot's output back to the original dot's shape. Here the second input's reshape | ||
| // preserves its contraction axis (second-to-last dim) but changes the free (N) dim, | ||
| // so the rewrite changes the dot's total element count and that trailing reshape is | ||
| // invalid (the "Reshape: Wrong number of elements" abort observed on an RT-DETR | ||
| // detection model on gfx1150). The pass must leave the module unchanged. |
… count
find_reshape_dot moves a reshape across a dot and then reshapes the new dot's output back to the original dot's shape. When the rewrite changes the total element count (e.g. the second input's reshape preserves its contraction axis but changes N), that trailing reshape is invalid and compilation aborts with "Reshape: Wrong number of elements" (observed on an RT-DETR detection model on gfx1150).
Skip the rewrite, and drop the speculatively-created dot, when the new dot's element count differs from the original dot's.
Motivation
Technical Details
Changelog Category
Add a
CHANGELOG.mdentry for any option other thanNot Applicable