Skip to content

simplify_reshapes: skip find_reshape_dot when it would change element…#4994

Draft
ycastill2-amd wants to merge 2 commits into
ROCm:developfrom
ycastill2-amd:fix-reshape-dot-element-count
Draft

simplify_reshapes: skip find_reshape_dot when it would change element…#4994
ycastill2-amd wants to merge 2 commits into
ROCm:developfrom
ycastill2-amd:fix-reshape-dot-element-count

Conversation

@ycastill2-amd

Copy link
Copy Markdown

… count

find_reshape_dot moves a reshape across a dot and then reshapes the new dot's output back to the original dot's shape. When the rewrite changes the total element count (e.g. the second input's reshape preserves its contraction axis but changes N), that trailing reshape is invalid and compilation aborts with "Reshape: Wrong number of elements" (observed on an RT-DETR detection model on gfx1150).

Skip the rewrite, and drop the speculatively-created dot, when the new dot's element count differs from the original dot's.

Motivation

Technical Details

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

… count

find_reshape_dot moves a reshape across a dot and then reshapes the new
dot's output back to the original dot's shape. When the rewrite changes the
total element count (e.g. the second input's reshape preserves its
contraction axis but changes N), that trailing reshape is invalid and
compilation aborts with "Reshape: Wrong number of elements" (observed on an
RT-DETR detection model on gfx1150).

Skip the rewrite, and drop the speculatively-created dot, when the new dot's
element count differs from the original dot's.

Signed-off-by: Yviel Castillejos <ycastill@amd.com>
@github-actions

Copy link
Copy Markdown
Contributor

Thank you for your contribution! Since this is an external pull request, a maintainer must review PR and add the "ok-to-test" label if it is approved for testing.

@codecov

codecov Bot commented Jun 19, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4994      +/-   ##
===========================================
- Coverage    92.73%   92.71%   -0.03%     
===========================================
  Files          594      596       +2     
  Lines        31340    31494     +154     
===========================================
+ Hits         29063    29197     +134     
- Misses        2277     2297      +20     
Files with missing lines Coverage Δ
src/simplify_reshapes.cpp 97.97% <100.00%> (+<0.01%) ⬆️

... and 23 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@pfultz2

pfultz2 commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

Can you add a unit test for this?

Add reshape_dot_changed_element_count, which builds a reshape -> dot -> reshape
pattern where the second input's reshape preserves its contraction axis but
changes N. Moving the reshapes across the dot would change the total element
count and make the trailing reshape invalid, so find_reshape_dot must skip the
rewrite and leave the module unchanged.

Signed-off-by: Yviel Castillejos <ycastill@amd.com>
@TedThemistokleous TedThemistokleous self-requested a review June 25, 2026 19:05
Comment thread src/simplify_reshapes.cpp
// Moving the reshape across the dot is only valid when the element count is
// preserved; otherwise the trailing reshape back to the original dot shape is
// invalid (observed creating an N->M element reshape that aborts compilation).
if(new_dot->get_shape().elements() != dot->get_shape().elements())

@TedThemistokleous TedThemistokleous Jun 25, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should do this before we add the new instruction to avoid doing a remove_instruction below. Just have an early return. The number of elements can be inferred from the inp and new_other for the check.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR hardens the simplify_reshapes pass by preventing find_reshape_dot from performing a reshape-across-dot rewrite when it would change the dot output’s total element count, which would make the trailing reshape invalid and trigger a reshape element-count error during compilation.

Changes:

  • Add an element-count guard in find_reshape_dot and bail out when the rewrite would change the dot output size.
  • Add a regression test ensuring the pass leaves the module unchanged for a reshape-dot case where the rewrite would change element count.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
src/simplify_reshapes.cpp Adds element-count validation to skip the reshape-dot rewrite when it would make the trailing reshape invalid.
test/simplify_reshapes_test.cpp Adds a regression test covering the element-count-changing reshape-dot scenario.

Comment thread src/simplify_reshapes.cpp
Comment on lines +1774 to +1781
// Moving the reshape across the dot is only valid when the element count is
// preserved; otherwise the trailing reshape back to the original dot shape is
// invalid (observed creating an N->M element reshape that aborts compilation).
if(new_dot->get_shape().elements() != dot->get_shape().elements())
{
m.remove_instruction(new_dot);
return;
}
Comment on lines +4554 to +4559
// find_reshape_dot would move the reshapes across the dot and then reshape the new
// dot's output back to the original dot's shape. Here the second input's reshape
// preserves its contraction axis (second-to-last dim) but changes the free (N) dim,
// so the rewrite changes the dot's total element count and that trailing reshape is
// invalid (the "Reshape: Wrong number of elements" abort observed on an RT-DETR
// detection model on gfx1150). The pass must leave the module unchanged.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants