[AIMIGRAPHX-885] MLP tower batched horizontal fusions#4723

Closed
TedThemistokleous wants to merge 42 commits into develop from MLP_prediction_towers

Conversation

@TedThemistokleous

Collaborator

Changes after further iteration and prototyping to ensure this fusion doesn't break other horizontal passes. Seeing a 4.8% improvement in model run performance at lower batch sizes.

Motivation

Customer model performance improvements as part of AIMIGRAPHX-885

Found an opportunity to fuse 4 identical MLP towers into batched GEMM chains. This uses horizontal fusion to perform the 1-to-N operation and combine the prediction towers into a smaller set of fused ops.

The change here performs a chain-aware fusion so that we can identify which fusion chains to compile while preserving other dot-based fusions (e.g. vertical fusions on dot), using layer depth > 2 as the filter for the chains we want to optimize.

The goal here is to reduce the number of dots and fuse them with the sigmoid and mul ops into one kernel. Batching many operations on smaller tensors into one larger op reduces the kernel launch overhead, which scales with the number of ops.
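
As a rough illustration of the batching idea (not the code in this PR), the sketch below compares one dot per tower against a single batched dot over the same data stacked along a leading batch axis. The names `dot` and `batched_dot` and the flat row-major buffers are assumptions made for the example; they are not MIGraphX types or APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative only: flat row-major buffers, not MIGraphX types.
// Unfused form: C = A * B for a single tower, where A is m x k and B is k x n.
std::vector<float> dot(const std::vector<float>& a,
                       const std::vector<float>& b,
                       std::size_t m,
                       std::size_t k,
                       std::size_t n)
{
    assert(a.size() == m * k && b.size() == k * n);
    std::vector<float> c(m * n, 0.0f);
    for(std::size_t i = 0; i < m; ++i)
        for(std::size_t p = 0; p < k; ++p)
            for(std::size_t j = 0; j < n; ++j)
                c[i * n + j] += a[i * k + p] * b[p * n + j];
    return c;
}

// Fused form: every tower's A and B are concatenated along a leading batch
// axis, so a single call (one launch on a GPU) covers all towers.
std::vector<float> batched_dot(const std::vector<float>& a,
                               const std::vector<float>& b,
                               std::size_t batch,
                               std::size_t m,
                               std::size_t k,
                               std::size_t n)
{
    assert(a.size() == batch * m * k && b.size() == batch * k * n);
    std::vector<float> c(batch * m * n, 0.0f);
    for(std::size_t t = 0; t < batch; ++t)
        for(std::size_t i = 0; i < m; ++i)
            for(std::size_t p = 0; p < k; ++p)
                for(std::size_t j = 0; j < n; ++j)
                    c[(t * m + i) * n + j] +=
                        a[(t * m + i) * k + p] * b[(t * k + p) * n + j];
    return c;
}
```

The batched call produces the same values as the per-tower calls concatenated in tower order, which is what lets the towers share one kernel as long as their shapes match.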

Technical Details

Creates a fusion block to find, track (via hashing) and trace chains that capture a tower of MLP prediction layers.

Compared to baseline develop I was seeing a 5% improvement at batch 1, from eliminating 24 individual dots and combining everything in the MLP chains (Sigmoid, Mul, Dot, Add) into an mlir_dot_add_sigmoid_mul kernel.

The result is that we go from 24 operations to 3 SiLU + 1 dot for these operations, since the fusion reduces 6 towers and 4 layers.
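
A minimal sketch of the find/track/trace idea, under stated assumptions: each traced chain is summarized by its op names and its number of dot layers, chains are bucketed by a key built from the op sequence, and only chains deeper than 2 layers that have at least one identical sibling are kept. The `chain` struct, `chain_key`, and `fusable_groups` are hypothetical names for illustration. The PR tracks chains via hashing; the sketch uses a string key in an ordered map instead so the output order is deterministic.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical summary of one traced tower chain.
struct chain
{
    std::vector<std::string> ops; // op names from the first dot to the tower output
    std::size_t depth;            // number of dot layers in the chain (assumed meaning)
};

// Build a key that is equal for structurally identical chains.
std::string chain_key(const chain& c)
{
    std::string key;
    for(const auto& op : c.ops)
        key += op + "->";
    return key;
}

// Return groups of chain indices that are candidates for batching:
// depth > 2 (min_depth = 3) and at least two chains sharing the same key.
std::vector<std::vector<std::size_t>> fusable_groups(const std::vector<chain>& chains,
                                                     std::size_t min_depth = 3)
{
    std::map<std::string, std::vector<std::size_t>> buckets;
    for(std::size_t i = 0; i < chains.size(); ++i)
    {
        if(chains[i].depth < min_depth)
            continue; // shallow chains are left to the other dot fusions
        buckets[chain_key(chains[i])].push_back(i);
    }
    std::vector<std::vector<std::size_t>> groups;
    for(const auto& entry : buckets)
        if(entry.second.size() > 1)
            groups.push_back(entry.second);
    return groups;
}
```

A real key would also need to cover shapes and the shared input, which this sketch leaves out.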

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

…break other horizontal passes, seeing a 4.8% improvement in lower batch model run performance
@TedThemistokleous TedThemistokleous self-assigned this Mar 31, 2026
@TedThemistokleous TedThemistokleous changed the title AIMIGRAPHX-885 MLP tower batched horizontal fusions [AIMIGRAPHX-885] MLP tower batched horizontal fusions Mar 31, 2026
Comment thread src/fuse_horizontal.cpp Outdated
Comment thread src/fuse_horizontal.cpp Outdated
Comment thread src/fuse_horizontal.cpp Outdated
@codecov

codecov Bot commented Mar 31, 2026


Codecov Report

❌ Patch coverage is 89.26380% with 35 lines in your changes missing coverage. Please review.

Files with missing lines      Patch %    Lines
src/simplify_reshapes.cpp     84.46%     30 Missing ⚠️
src/fuse_horizontal.cpp       96.24%     5 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4723      +/-   ##
===========================================
- Coverage    92.87%   92.83%   -0.04%     
===========================================
  Files          585      585              
  Lines        30122    30447     +325     
===========================================
+ Hits         27974    28264     +290     
- Misses        2148     2183      +35     
Files with missing lines      Coverage Δ
src/fuse_horizontal.cpp       96.98% <96.24%> (-1.02%) ⬇️
src/simplify_reshapes.cpp     95.62% <84.46%> (-2.33%) ⬇️

Simplify some initial steps into separate calls before tackling the last step, which is quite messy
… for generalized MLP tower.

Allows us to generalize this type of pattern much more easily, to not just SiLU but other activations and chains (layernorm, silu, etc.).
@TedThemistokleous TedThemistokleous marked this pull request as ready for review April 2, 2026 20:16
Fixes issue I was seeing since we were aliasing matches for other reshape patterns.
Comment thread src/simplify_reshapes.cpp
Comment thread src/simplify_algebra.cpp Outdated
Comment thread src/fuse_horizontal.cpp Outdated
Comment thread src/fuse_horizontal.cpp Outdated
Comment thread src/fuse_horizontal.cpp Outdated
Comment thread src/fuse_horizontal.cpp Outdated
TedThemistokleous and others added 4 commits April 7, 2026 23:06
The five slice_squeeze tests that relied on run_opt_pass (iterating
simplify_reshapes + simplify_algebra + CSE) belong in
optimize_module_test since they exercise the combined pass pipeline.
Move them there, repoint them at the existing run_pass (which runs
optimize_module), update the two expected outputs that change under
propagate_constant (constant folding collapses unsqueeze+concat of
literals), and remove run_opt_pass and the now-unused
simplify_algebra include from simplify_reshapes_test.

Made-with: Cursor
Already handled with slice
@TedThemistokleous

TedThemistokleous commented Apr 9, 2026

Collaborator Author

I would like to see a perf report run before we merge this in, to make sure this doesn't cause a perf regression in other models. I know in the past when I have done changes like this, there are usually some other changes needed to avoid regressions in other models.

Got to the end of the line with this one, I think. I'm getting around a 2% boost after it's all said and done. Analyzing this further, after the changes from develop we're losing the mlir_slice_sigmoid_mul_dot_add_unsqueeze kernels. I put in more effort to add those back; making this a bit more robust gives us a small boost and actually reduces the kernel count further, but the gain seems marginal.

Throughput goes from 230 QPS on develop to 237 QPS with this change, which adds a newly created mlir_dot_add_sigmoid_mul_squeeze kernel.

This ended up replacing 16 mlir_dot_add_sigmoid_mul kernels with 6 mlir_dot_add_sigmoid_mul_squeeze kernels.

Not sure if the effort here is worth the payoff; the changes need more cleanup. We can discuss this further if you're free tomorrow/Friday. I kept a bunch of intermediate perf IR reports if you want to go over these.

@TedThemistokleous

Collaborator Author

@pfultz2 following up on this: without --enable-offload-copy we do see a perf boost, so this isn't wasted work. Trying to tackle overhead right now, as the offload copy adds so much that it overshadows any perf gains we see when using this at lower batch sizes. With the offload copy we get 2%, but that's including hipMemcpy calls that take upwards of 50-60% of the run.

Comment thread src/fuse_horizontal.cpp

// Check if every dot feeds into add(dot, broadcast(...)). If so, return
// the add instructions and bias broadcast instructions.
static bool detect_downstream_biases(const std::vector<instruction_ref>& dots,

@pfultz2 pfultz2 Apr 10, 2026

Collaborator

This shouldn't be fusing the downstream elementwise operators. That should be handled by the other horizontal fusions we have.

Collaborator Author

Which one? Right now this gives me an 11% perf boost, but since the model I'm running this on has a copy_from_gpu after this, there's a bit of a slowdown until I do something on the ORT side for memory coalescing of the outputs. This shows up when we use --enable-offload-copy too.

Collaborator

Which one?

find_splits in simplify_algebra.

Comment thread src/simplify_reshapes.cpp
@pfultz2

pfultz2 commented Apr 10, 2026

Collaborator

I don't know if maybe this should be split into 3 PRs: one for dot_horizontal_fusion, one for find_concat_reshape, and another for find_slice_squeeze. At least the find_concat_reshape might need to be a PR on its own.

@TedThemistokleous

Collaborator Author

I don't know if maybe this should be split into 3 PRs: one for dot_horizontal_fusion, one for find_concat_reshape, and another for find_slice_squeeze. At least the find_concat_reshape might need to be a PR on its own.

Not opposed to that. There are actually more changes on top of this, made until I hit a dead end, that I haven't pushed up to this PR. They likely need to get cleaned up again.

The bigger issue, I believe, is that these updates, which do fix things and give a smaller boost (7% without offload copy), are completely overshadowed by the copies, resulting in not much of a difference. I've got another changeset I'll introduce once I'm done testing that handles some additional overhead with Context, as I was seeing memory idle gaps in our execution pipeline between kernels for the lower batches.
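
For a back-of-the-envelope view of why the copies hide the gain, here is a small Amdahl-style sketch. It assumes the copy time is a fixed fraction of the baseline run and is unaffected by the fusion, and it treats the 7% figure as the speedup of the non-copy portion; both are simplifying assumptions rather than measurements, and `overall_speedup` is a hypothetical helper.

```cpp
#include <cassert>
#include <cmath>

// copy_fraction:  share of the baseline run spent in host/device copies
//                 (assumed unchanged by the fusion).
// kernel_speedup: speedup of the remaining, non-copy portion (1.07 for 7%).
// Returns the end-to-end speedup, following Amdahl's law.
double overall_speedup(double copy_fraction, double kernel_speedup)
{
    assert(copy_fraction >= 0.0 && copy_fraction < 1.0 && kernel_speedup > 0.0);
    return 1.0 / (copy_fraction + (1.0 - copy_fraction) / kernel_speedup);
}
```

With a copy fraction of 0.55 and a kernel speedup of 1.07 this gives about 1.03, which is in the same ballpark as the roughly 2% seen with offload copy enabled.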

@CharlieL7 CharlieL7 removed their request for review April 14, 2026 19:18
TedThemistokleous and others added 5 commits May 4, 2026 18:07
…slices

trace_pw_chain now detects diamond convergence patterns where a node
has exactly two outputs that merge at a single pointwise instruction
(e.g. SiLU: slice → sigmoid → mul(slice, sigmoid)).  The replay logic
navigates multi-output template instructions by op name rather than
assuming a single output.

Also fixes two correctness issues:
- Skip dead terminals (no consumers) to prevent infinite loop in the
  inner while(changed) iteration
- Require slices to exactly tile the bounding range to avoid wasteful
  computation over gaps

Made-with: Cursor
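
The "slices exactly tile the bounding range" requirement from the commit above can be sketched as follows. This is a simplified one-axis version with half-open `[start, end)` ranges; `range` and `tiles_exactly` are hypothetical names, not the helpers used in the PR.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Half-open slice range [start, end) along a single axis.
using range = std::pair<std::size_t, std::size_t>;

// True when the slices cover their bounding range with no gaps and no
// overlaps, so computing over the whole range does no wasted work.
bool tiles_exactly(std::vector<range> slices)
{
    if(slices.empty())
        return false;
    std::sort(slices.begin(), slices.end());
    for(std::size_t i = 0; i < slices.size(); ++i)
    {
        if(slices[i].first >= slices[i].second)
            return false; // empty or inverted slice
        if(i > 0 && slices[i].first != slices[i - 1].second)
            return false; // gap or overlap with the previous slice
    }
    return true;
}
```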
group_key now includes whether a dot's output feeds into a SiLU
pattern (optionally through a bias add).  This prevents dots with
SiLU from being batched with dots without SiLU, which previously
caused detect_downstream_silu to fail for the entire mixed group.

With this change, SiLU dots form their own batch (enabling full
bias + SiLU absorption) while non-SiLU dots are grouped separately.

Made-with: Cursor
…ting

slice_squeeze_pw_silu_chain: SiLU is now hoisted above slices, so
the expected output shows add → sigmoid → mul → slice → squeeze
instead of per-slice SiLU.

hoist_silu_above_slices_with_unsqueeze_concat: After hoisting, the
slice results are non-contiguous views, so concat(unsqueeze(slice))
cannot be simplified to a plain reshape. Updated m2 to match the
actual output.

Made-with: Cursor
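
The hoisting in these test updates relies on pointwise ops commuting with slice: applying SiLU to the full tensor and then slicing yields the same values as slicing first and applying SiLU per slice. A minimal 1-D sketch of that property, with `silu`, `apply_silu`, and `slice` as illustrative stand-ins rather than MIGraphX operators:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU written as x * sigmoid(x), mirroring the sigmoid -> mul pattern.
float silu(float x) { return x * (1.0f / (1.0f + std::exp(-x))); }

// Apply SiLU elementwise.
std::vector<float> apply_silu(const std::vector<float>& v)
{
    std::vector<float> out;
    out.reserve(v.size());
    for(float x : v)
        out.push_back(silu(x));
    return out;
}

// Copy the half-open range [start, end) out of v.
std::vector<float> slice(const std::vector<float>& v, std::size_t start, std::size_t end)
{
    assert(start <= end && end <= v.size());
    return std::vector<float>(v.begin() + start, v.begin() + end);
}
```

For any valid range, `slice(apply_silu(x), a, b)` and `apply_silu(slice(x, a, b))` hold the same elements; the hoisted form runs one SiLU over the full tensor instead of one per slice.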
@TedThemistokleous TedThemistokleous requested a review from bdevorem May 6, 2026 13:34
Comment thread src/simplify_reshapes.cpp
return {sig, current};
}

static void hoist_pointwise_above_slices(module& m)

Collaborator Author

Delete

Comment thread src/simplify_reshapes.cpp
// outputs that converge to a single pointwise merge (e.g. SiLU:
// slice → sigmoid → mul(slice, sigmoid)).
static std::pair<pw_chain_sig, instruction_ref>
trace_pw_chain(instruction_ref start)

Collaborator Author

Delete

@TedThemistokleous

Collaborator Author

concat_unsqueeze - #4984
Two others to come

causten pushed a commit that referenced this pull request Jun 25, 2026
Adds a matcher for the concat_reshape matching. Partial piece when breaking up #4723
@TedThemistokleous

Collaborator Author

Closing, as this is split out into #5014, #5004, and #4984.


4 participants