Hoist and horizontal dot #5014

Open
TedThemistokleous wants to merge 16 commits into develop from hoist_and_horizontal_dot

@TedThemistokleous

Collaborator

Motivation

Final piece split from #4723

Adds a hoist pass to simplify reshapes that moves pointwise operations repeated across several slices above those slices. This sets us up to perform a horizontal dot fusion across each of the branches.

Technical Details

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

Hoist a pointwise (or SiLU diamond) chain that is replicated across
several sibling slices of the same instruction above the slices: the
chain is computed once on the bounding slice and each consumer reads its
sub-range back out. This exposes a single wide pointwise op that
downstream passes (find_splits, horizontal fusion) can recombine instead
of N identical narrow ops.

trace_pw_chain handles both linear pointwise chains and the SiLU diamond
(x -> sigmoid -> mul(x, sigmoid)). Slices are only hoisted when they
exactly tile the bounding range so no redundant elements are computed.
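
The hoist described above can be illustrated with a minimal, standalone sketch. It does not use the MIGraphX API: `slice_range`, `slice`, `relu`, `per_slice`, and `hoisted` are hypothetical names, tensors are flattened to 1-D vectors, and relu stands in for any pointwise op. It only shows that computing the op once on the bounding slice and re-slicing gives the same values as one narrow op per slice.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Half-open range [start, end) along one axis; a stand-in for a slice op.
struct slice_range
{
    std::size_t start;
    std::size_t end;
};

// Copy x[r.start, r.end) into a new vector (a 1-D "slice").
std::vector<float> slice(const std::vector<float>& x, slice_range r)
{
    return std::vector<float>(x.begin() + static_cast<std::ptrdiff_t>(r.start),
                              x.begin() + static_cast<std::ptrdiff_t>(r.end));
}

// Elementwise relu, standing in for any pointwise op.
std::vector<float> relu(std::vector<float> x)
{
    std::transform(x.begin(), x.end(), x.begin(), [](float v) { return std::max(v, 0.0f); });
    return x;
}

// Before the hoist: one narrow relu per sibling slice.
std::vector<std::vector<float>> per_slice(const std::vector<float>& x,
                                          const std::vector<slice_range>& rs)
{
    std::vector<std::vector<float>> out;
    for(auto r : rs)
        out.push_back(relu(slice(x, r)));
    return out;
}

// After the hoist: one wide relu on the bounding slice, then each consumer
// reads its sub-range back out. Assumes rs is non-empty.
std::vector<std::vector<float>> hoisted(const std::vector<float>& x,
                                        const std::vector<slice_range>& rs)
{
    slice_range bound = rs.front();
    for(auto r : rs)
    {
        bound.start = std::min(bound.start, r.start);
        bound.end   = std::max(bound.end, r.end);
    }
    auto wide = relu(slice(x, bound));
    std::vector<std::vector<float>> out;
    for(auto r : rs)
        out.push_back(slice(wide, {r.start - bound.start, r.end - bound.start}));
    return out;
}
```

For example, with `x = {-1, 2, -3, 4, -5, 6}` and sibling slices `[1, 3)` and `[3, 5)`, both `per_slice` and `hoisted` return `{2, 0}` and `{4, 0}`; the hoisted form runs relu once on `[1, 5)`.
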
Cover the hoist transform in isolation (simplify_reshapes + DCE only):
- relu hoisted above two sibling unit slices
- SiLU diamond (sigmoid/mul) hoisted above sibling slices
- no hoist when slices leave a gap in the bounding range
- no hoist when sibling slices feed different pointwise chains

Batch structurally-identical dot operations (same activation/weight lens
and output type, constant weights) into a single batched GEMM by stacking
activations and weights along a new leading axis, then slice and squeeze
the per-dot results back out.

Per review on #4723, only the dots are fused here; downstream elementwise
work (bias add, SiLU, ...) is left in place so the existing fusions
(find_splits in simplify_algebra and the pointwise/MLIR fusions) can
recombine it on top of the batched dot rather than duplicating that logic.
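
As a rough illustration of the batching described above, the following standalone sketch stacks several same-shaped activation and weight matrices along a new leading axis, runs one batched GEMM, and slices each result back out. It is not the MIGraphX implementation: `dims`, `dot`, `stack`, `batched_dot`, `slice_squeeze`, and `fused_dots` are hypothetical helpers, and matrices are row-major `std::vector<float>`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using matrix = std::vector<float>; // row-major storage

// Shared shape of every dot in the group: [m, k] x [k, n] -> [m, n].
struct dims
{
    std::size_t m;
    std::size_t k;
    std::size_t n;
};

// Reference single dot, used to check the fused result.
matrix dot(const matrix& a, const matrix& b, dims d)
{
    matrix c(d.m * d.n, 0.0f);
    for(std::size_t i = 0; i < d.m; i++)
        for(std::size_t p = 0; p < d.k; p++)
            for(std::size_t j = 0; j < d.n; j++)
                c[i * d.n + j] += a[i * d.k + p] * b[p * d.n + j];
    return c;
}

// Stack equally sized matrices along a new leading (batch) axis.
matrix stack(const std::vector<matrix>& xs)
{
    matrix out;
    for(const auto& x : xs)
        out.insert(out.end(), x.begin(), x.end());
    return out;
}

// One batched GEMM: [batch, m, k] x [batch, k, n] -> [batch, m, n].
matrix batched_dot(const matrix& a, const matrix& b, std::size_t batch, dims d)
{
    matrix c(batch * d.m * d.n, 0.0f);
    for(std::size_t bi = 0; bi < batch; bi++)
        for(std::size_t i = 0; i < d.m; i++)
            for(std::size_t p = 0; p < d.k; p++)
                for(std::size_t j = 0; j < d.n; j++)
                    c[(bi * d.m + i) * d.n + j] +=
                        a[(bi * d.m + i) * d.k + p] * b[(bi * d.k + p) * d.n + j];
    return c;
}

// Slice batch entry bi out of the stacked result and drop the leading axis.
matrix slice_squeeze(const matrix& c, std::size_t bi, dims d)
{
    auto first = c.begin() + static_cast<std::ptrdiff_t>(bi * d.m * d.n);
    return matrix(first, first + static_cast<std::ptrdiff_t>(d.m * d.n));
}

// Fused form: one batched GEMM, then one slice + squeeze per original dot.
std::vector<matrix> fused_dots(const std::vector<matrix>& as,
                               const std::vector<matrix>& bs,
                               dims d)
{
    auto c = batched_dot(stack(as), stack(bs), as.size(), d);
    std::vector<matrix> out;
    for(std::size_t bi = 0; bi < as.size(); bi++)
        out.push_back(slice_squeeze(c, bi, d));
    return out;
}
```

Only the dots are combined here; any elementwise work after each dot would stay on the sliced results, matching the split of responsibilities described in the commit message.
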
Cover the batched-dot fusion (fuse_horizontal + DCE):
- two shape-identical dots with constant weights batch into one GEMM
- dots with non-constant weights are not fused
- dots with mismatched activation/weight shapes do not group
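
The grouping rule these tests exercise can be sketched as follows. This is an assumption-laden illustration rather than the code in fuse_horizontal: `dot_desc` and `group_fusable_dots` are hypothetical, each dot is reduced to its activation lens, weight lens, output type, and a constant-weight flag, and a group is assumed to need at least two members before it is worth batching.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical summary of one dot instruction.
struct dot_desc
{
    std::vector<std::size_t> act_lens;
    std::vector<std::size_t> weight_lens;
    std::string out_type;
    bool constant_weights;
};

// Return groups of dot indices that share activation lens, weight lens and
// output type. Dots with non-constant weights are never grouped, and groups
// with a single member are dropped (assumed: nothing to batch).
std::vector<std::vector<std::size_t>> group_fusable_dots(const std::vector<dot_desc>& dots)
{
    using key = std::tuple<std::vector<std::size_t>, std::vector<std::size_t>, std::string>;
    std::map<key, std::vector<std::size_t>> buckets;
    for(std::size_t i = 0; i < dots.size(); i++)
    {
        const auto& d = dots[i];
        if(!d.constant_weights)
            continue;
        buckets[key{d.act_lens, d.weight_lens, d.out_type}].push_back(i);
    }
    std::vector<std::vector<std::size_t>> groups;
    for(const auto& kv : buckets)
    {
        if(kv.second.size() > 1)
            groups.push_back(kv.second);
    }
    return groups;
}
```

Given two constant-weight dots of shape `[4, 8] x [8, 16]`, a third with the same shape but non-constant weights, and a fourth with weights `[8, 32]`, only the first two end up in a group.
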
Exercise simplify_reshapes (hoist_pointwise_above_slices) and
fuse_horizontal (dot_horizontal_fusion) in a single pipeline on a module
containing both opportunities: a SiLU replicated across sibling slices
and a set of parallel constant-weight dots. Verify the dots collapse to a
single batched GEMM and the per-slice SiLUs collapse to one sigmoid/mul.

@codecov

codecov Bot commented Jun 25, 2026


Codecov Report

❌ Patch coverage is 97.67442% with 2 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/simplify_algebra.cpp | 96.23% | 2 Missing ⚠️ |
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #5014      +/-   ##
===========================================
+ Coverage    92.73%   92.74%   +0.01%     
===========================================
  Files          596      596              
  Lines        31813    31896      +83     
===========================================
+ Hits         29500    29581      +81     
- Misses        2313     2315       +2     
| Files with missing lines | Coverage Δ |
| --- | --- |
| src/fuse_horizontal.cpp | 99.39% <100.00%> (+0.15%) ⬆️ |
| src/simplify_reshapes.cpp | 98.15% <ø> (ø) |
| src/simplify_algebra.cpp | 97.51% <96.23%> (-0.05%) ⬇️ |

}

EXPECT(m1.sort() == m2.sort());
}

Collaborator

This should already be handled by find_splits.

Collaborator Author

Understood. Modified the changes so we can reuse as much of find_splits as possible, and got rid of the entire tracing section so it's just a case where we handle the smaller cases. The diamond pattern of ops to lift before the slice should be handled, like we talked about, after fuse_pointwise.

Let me know what you think; I'm just running some additional tests on this.

Comment thread test/simplify_reshapes_test.cpp
Comment thread src/simplify_reshapes.cpp Outdated
Comment thread src/fuse_horizontal.cpp Outdated
@gh-app-migraphx-bot-pr-write

gh-app-migraphx-bot-pr-write Bot commented Jun 26, 2026

| Test | Batch | New Rate (38aa99) | Old Rate (9ad497)* | Diff | Status |
| --- | --- | --- | --- | --- | --- |
| torchvision-resnet50 | 64 | 2,978.53 | 3,148.35 | -5.39% | 🔴 |
| torchvision-resnet50_fp16 | 64 | 2,527.61 | 6,662.40 | -62.06% | 🔴 |
| torchvision-densenet121 | 32 | 839.46 | 2,707.30 | -68.99% | 🔴 |
| torchvision-densenet121_fp16 | 32 | 4,541.61 | 4,558.12 | -0.36% | |
| torchvision-inceptionv3 | 32 | 1,015.85 | 1,797.44 | -43.48% | 🔴 |
| torchvision-inceptionv3_fp16 | 32 | 2,822.59 | 2,824.56 | -0.07% | |
| cadene-inceptionv4 | 16 | 823.21 | 807.16 | 1.99% | |
| cadene-resnext64x4 | 16 | 784.25 | 387.68 | 102.29% | 🔆 |
| slim-mobilenet | 64 | 8,212.61 | 8,283.02 | -0.85% | |
| slim-nasnetalarge | 64 | 199.61 | 229.53 | -13.03% | 🔴 |
| slim-resnet50v2 | 64 | 3,164.78 | 3,177.49 | -0.40% | |
| bert-mrpc-onnx | 8 | 1,171.90 | 1,171.12 | 0.07% | |
| bert-mrpc-tf | 1 | 486.10 | 479.25 | 1.43% | |
| pytorch-examples-wlang-gru | 1 | 323.06 | 327.85 | -1.46% | |
| pytorch-examples-wlang-lstm | 1 | 470.75 | 457.62 | 2.87% | |
| torchvision-resnet50_1 | 1 | 285.84 | 765.65 | -62.67% | 🔴 |
| cadene-dpn92_1 | 1 | 442.68 | 447.53 | -1.08% | |
| cadene-resnext101_1 | 1 | 78.48 | 366.68 | -78.60% | 🔴 |
| onnx-taau-downsample | 1 | 186.56 | 401.14 | -53.49% | 🔴 |
| dlrm-criteoterabyte | 1 | 30.81 | 32.57 | -5.41% | 🔴 |
| dlrm-criteoterabyte_fp16 | 1 | 51.69 | 52.61 | -1.74% | |
| agentmodel | 1 | 11,349.41 | 8,132.20 | 39.56% | 🔆 |
| unet_fp16 | 2 | 51.07 | 57.28 | -10.84% | 🔴 |
| resnet50v1_fp16 | 1 | 935.04 | 932.12 | 0.31% | |
| resnet50v1_int8 | 1 | 947.98 | 943.10 | 0.52% | |
| bert_base_cased_fp16 | 64 | 1,073.86 | 1,102.38 | -2.59% | |
| bert_large_uncased_fp16 | 32 | 329.25 | 347.33 | -5.20% | 🔴 |
| bert_large_fp16 | 1 | 204.99 | 205.67 | -0.33% | |
| distilgpt2_fp16 | 16 | 2,089.03 | 2,094.34 | -0.25% | |
| yolov5s | 1 | 596.67 | 597.91 | -0.21% | |
| tinyllama | 1 | 45.95 | 45.98 | -0.07% | |
| vicuna-fastchat | 1 | 44.10 | 44.11 | -0.01% | |
| whisper-tiny-encoder | 1 | 417.37 | 420.33 | -0.70% | |
| whisper-tiny-decoder | 1 | 414.53 | 416.77 | -0.54% | |
| llama2_7b | 1 | 20.37 | 20.50 | -0.64% | |
| qwen1.5-7b | 1 | 23.60 | 23.60 | -0.03% | |
| phi3-3.8b | 1 | 26.78 | 26.79 | -0.02% | |
| llama3-8b | 1 | 3.22 | 21.77 | -85.23% | 🔴 |
| whisper-large-encoder | 1 | 10.28 | 10.31 | -0.35% | |
| whisper-large-decoder | 1 | 105.74 | 105.11 | 0.61% | |
| mistral-7b | 1 | 23.66 | 23.78 | -0.50% | |
| FLUX.1-schnell | 1 | 757.00 | 751.64 | 0.71% | |

Regressions detected 🔴

* No develop baseline was found for this PR's branch point; compared against the latest available develop run instead.

@gh-app-migraphx-bot-pr-write

gh-app-migraphx-bot-pr-write Bot commented Jun 26, 2026

Test Status Result
bert-mrpc-onnx PASSED: MIGraphX meets tolerance
bert-mrpc-tf PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-gru PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-lstm PASSED: MIGraphX meets tolerance
dlrm-criteoterabyte PASSED: MIGraphX meets tolerance
agentmodel PASSED: MIGraphX meets tolerance
unet PASSED: MIGraphX meets tolerance
resnet50v1 PASSED: MIGraphX meets tolerance
bert_base_cased_fp16 PASSED: MIGraphX meets tolerance
bert_large_uncased_fp16 🔴 FAILED: MIGraphX is not within tolerance - check verbose output
bert_large PASSED: MIGraphX meets tolerance
yolov5s PASSED: MIGraphX meets tolerance
tinyllama PASSED: MIGraphX meets tolerance
vicuna-fastchat PASSED: MIGraphX meets tolerance
whisper-tiny-encoder PASSED: MIGraphX meets tolerance
whisper-tiny-decoder PASSED: MIGraphX meets tolerance
distilgpt2_fp16 PASSED: MIGraphX meets tolerance
llama2_7b PASSED: MIGraphX meets tolerance
qwen1.5-7b PASSED: MIGraphX meets tolerance
phi3-3.8b PASSED: MIGraphX meets tolerance
llama3-8b PASSED: MIGraphX meets tolerance
whisper-large-encoder PASSED: MIGraphX meets tolerance
whisper-large-decoder PASSED: MIGraphX meets tolerance
mistral-7b PASSED: MIGraphX meets tolerance
FLUX.1-schnell PASSED: MIGraphX meets tolerance

TedThemistokleous and others added 8 commits June 26, 2026 14:57
Replace the bespoke pointwise-chain hoisting engine (chain tracing, equivalence,
replay) in find_splits with a unary-only hoist_bounded_splits that reuses the
existing get_split_groups/is_fusable/split_groups_are_dependent helpers.

When sibling slices only partially tile their tensor, a single-input op
replicated across them is computed once on the bounding slice and each consumer
re-slices its sub-range. Multi-op linear chains are peeled one op per fixpoint
iteration, and SiLU/diamond patterns are handled once fuse_pointwise collapses
them into a single pointwise op, so no diamond-specific code is needed here.

Co-authored-by: Cursor <cursoragent@cursor.com>

The slim find_splits hoist only widens single-input ops, so a raw sigmoid/mul
SiLU diamond is no longer hoisted directly. Update the two SiLU tests to run
fuse_pointwise first (collapsing each per-slice SiLU into one pointwise op) and
assert on the resulting fused pointwise: hoist_silu_above_slices now compares
against a hoisted pointwise program, and the end-to-end test asserts a single
pointwise plus a single batched dot.

Co-authored-by: Cursor <cursoragent@cursor.com>

Ensures we're not grabbing everything with bias, which generated erroneous add_kernels by fusing the dot early.

Comment thread test/simplify_algebra_test.cpp Outdated
mm->add_return({sub0, sub1});
}
EXPECT(p1 == p2);
}

Collaborator

This test case should go into a different test suite like fuse_pointwise or fuse_pointwise_reduce.

}

EXPECT(m1.sort() == m2.sort());
}

Collaborator

It seems like we handle this already, except we currently don't insert the initial slice.

Comment thread src/simplify_algebra.cpp Outdated
// op per fixpoint iteration (after the first hoist the sub-slices fully tile
// the wide result, so get_splits then applies), and a fused pointwise op
// (e.g. a SiLU produced by fuse_pointwise) already appears as a single op.
// ------------------------------------------------------------------------

Collaborator

This comment should be rewritten to make what it is doing clearer. As I understand it from the test case, we want slice{0,1} -> pw, slice{1,2} -> pw to be written as slice{0,2} -> pw -> (slice{0,1}, slice{1,2}).
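
To make the index bookkeeping in that rewrite concrete, here is a small standalone sketch. It assumes the consumer slices are expressed relative to the start of the wide (bounding) result; `axis_slice`, `hoist_plan`, and `plan_hoist` are hypothetical names, not functions in simplify_algebra.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Half-open slice [start, end) along the sliced axis.
struct axis_slice
{
    std::size_t start;
    std::size_t end;
};

bool operator==(axis_slice a, axis_slice b) { return a.start == b.start && a.end == b.end; }

// Result of the rewrite: one bounding slice feeding the pointwise op, plus
// one consumer slice per original slice, relative to the bounding start.
struct hoist_plan
{
    axis_slice bounding;
    std::vector<axis_slice> consumers;
};

// Compute the bounding slice and re-based consumer slices.
// Assumes slices is non-empty.
hoist_plan plan_hoist(const std::vector<axis_slice>& slices)
{
    hoist_plan plan{slices.front(), {}};
    for(auto s : slices)
    {
        plan.bounding.start = std::min(plan.bounding.start, s.start);
        plan.bounding.end   = std::max(plan.bounding.end, s.end);
    }
    for(auto s : slices)
        plan.consumers.push_back({s.start - plan.bounding.start, s.end - plan.bounding.start});
    return plan;
}
```

With the slices from the comment, `{0,1}` and `{1,2}`, the bounding slice is `{0,2}` and the consumers stay `{0,1}` and `{1,2}`. With `{2,3}` and `{3,5}`, the bounding slice is `{2,5}` and the consumers become `{0,1}` and `{1,3}`.
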

Comment thread src/simplify_algebra.cpp Outdated
}
}
return changed;
}

Collaborator

Remove this function; it already duplicates a lot of what is done in the main apply function.

Comment thread src/simplify_algebra.cpp Outdated
// The sibling slices only partially tile their tensor, so the
// full-cover split fusion does not apply; hoist a replicated
// single-input op above their bounding slice instead.
hoist_bounded_splits(m, ins);

Collaborator

We shouldn't go through a separate path. We can add a partial flag to get_splits (defaulting to false, because this function is used elsewhere) where we skip checking that the slices start at 0 or that the end is equal to the dimension (I couldn't find that check, so I am not sure where this is happening).

Then before we insert the op we can check if the slices are partial then insert a slice and then the op. There is no need to reimplement another version of this function.
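
One way to picture the suggested flag is the standalone sketch below. It is not the real get_splits, whose signature and checks are not shown in this thread: `interval` and `slices_are_splits` are hypothetical, and the only behaviour modelled is that slices must be contiguous, with the start-at-0 and end-at-dimension checks skipped when `partial` is true.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using interval = std::pair<std::size_t, std::size_t>; // half-open [start, end)

// Hypothetical stand-in for the coverage check discussed above.
// partial == false: slices must tile the whole axis [0, dim).
// partial == true:  slices only need to tile their own bounding range,
//                   so the caller would insert a bounding slice first.
bool slices_are_splits(std::vector<interval> slices, std::size_t dim, bool partial = false)
{
    if(slices.empty())
        return false;
    std::sort(slices.begin(), slices.end());
    // Each slice must begin exactly where the previous one ended:
    // no gaps and no overlaps.
    for(std::size_t i = 1; i < slices.size(); i++)
    {
        if(slices[i].first != slices[i - 1].second)
            return false;
    }
    if(partial)
        return true;
    return slices.front().first == 0 && slices.back().second == dim;
}
```

On an axis of length 4, `{0,2}, {2,4}` passes either way, `{1,2}, {2,3}` passes only with `partial = true`, and `{0,1}, {2,3}` fails in both modes because of the gap.
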

Collaborator Author

Sure, that can be done. Hold on.

- move tests
- Adjust changes to find_splits with partial flag.

Labels

high priority A PR with high priority for review and merging.


2 participants