[AIMIGRAPHX-885] MLP tower batched horizontal fusions #4723
TedThemistokleous wants to merge 42 commits into
…break other horizontal passes, seeing a 4.8% improvement in lower batch model run performance
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4723      +/-   ##
===========================================
- Coverage    92.87%   92.83%   -0.04%
===========================================
  Files          585      585
  Lines        30122    30447     +325
===========================================
+ Hits         27974    28264     +290
- Misses        2148     2183      +35
Simplify some initial steps into separate calls before tackling the last step, which is quite messy
… for generalized MLP tower. Allows us to generalize this type of pattern to not just SiLU but other activations and chains (layernorm, SiLU, etc.) much more easily.
Fixes issue I was seeing since we were aliasing matches for other reshape patterns.
…ruction in matcher
Do this to reuse some of the reshape matchers that already handle most of the functionality that this matcher uses.
The five slice_squeeze tests that relied on run_opt_pass (iterating simplify_reshapes + simplify_algebra + CSE) belong in optimize_module_test since they exercise the combined pass pipeline. Move them there, repoint them at the existing run_pass (which runs optimize_module), update the two expected outputs that change under propagate_constant (constant folding collapses unsqueeze+concat of literals), and remove run_opt_pass and the now-unused simplify_algebra include from simplify_reshapes_test. Made-with: Cursor
Already handled with slice
Got to the end of the line with this one, I think. I'm getting around a 2% boost after it's all said and done... Analyzing this further after changes from develop, we're losing the Seeing develop going from 230 -> 237 QPS and it adds a newly created This ended up reducing 16 mlir_dot_add_sigmoid_mul kernels to 6 of the Not sure if the effort here is worth the payoff; the changes need more cleanup. We can discuss this further if you're free tomorrow/Friday. I kept a bunch of intermediate perf IR reports if you want to go over these.
@pfultz2 following up on this: without --enable-offload-copy we do see a perf boost, so this isn't wasted work. I'm trying to tackle overhead right now, as the offload copy adds so much that it overshadows any perf gains we see when using this on lower batch sizes. With the offload copy we get 2%, but that's including hipMemcpy calls that take upwards of 50-60% of the run.
// Check if every dot feeds into add(dot, broadcast(...)). If so, return
// the add instructions and bias broadcast instructions.
static bool detect_downstream_biases(const std::vector<instruction_ref>& dots,
This shouldn't be fusing the downstream elementwise operators. That should be handled by the other horizontal fusions we have.
Which one? Right now this gives me an 11% perf boost, but since the model I'm running this on has a copy_from_gpu after this, there's a bit of a slowdown until I do something on the ORT side for memory coalescing for the outputs. This shows up when we do --enable-offload-copy too.
Which one?
find_splits in simplify_algebra.
I don't know if maybe this should be split into 3 PRs, one for
Not opposed to that. There are actually more changes on top of this, which I haven't pushed up to this PR, that added some additional changes until I hit a dead end. They likely need to get cleaned up again. The bigger issue, I believe, is that these updates, which do fix things and give a smaller boost (7% without offload copy), are completely overshadowed by the copies, resulting in really not much of a difference. I've got another changeset I'll introduce once I'm done testing that handles some additional overhead with Context, as I was seeing memory idle gaps in our execution pipeline between kernels for the lower batches.
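To make the "overshadowed by the copies" point concrete, here is a rough back-of-the-envelope sketch. It assumes the 7% gain applies only to the non-copy part of the run and that the copies take a fixed fraction of the baseline time; the function name and that model are illustrative assumptions, not measurements from this PR.

```cpp
#include <cassert>

// Overall speedup when only the non-copy portion of the run gets faster.
// copy_fraction:  share of baseline time spent in host<->device copies (unchanged).
// kernel_speedup: speedup of the remaining portion, e.g. 1.07 for a 7% gain.
// Both the name and the simple two-part timing model are assumptions.
double end_to_end_speedup(double copy_fraction, double kernel_speedup)
{
    assert(copy_fraction >= 0.0 && copy_fraction < 1.0);
    assert(kernel_speedup > 0.0);
    const double new_time = copy_fraction + (1.0 - copy_fraction) / kernel_speedup;
    return 1.0 / new_time;
}
```

With copies at 55% of the run, `end_to_end_speedup(0.55, 1.07)` comes out at roughly 1.03, which is in the same ballpark as the 2% seen with offload copy enabled.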
…slices
trace_pw_chain now detects diamond convergence patterns where a node has exactly two outputs that merge at a single pointwise instruction (e.g. SiLU: slice → sigmoid → mul(slice, sigmoid)). The replay logic navigates multi-output template instructions by op name rather than assuming a single output.
Also fixes two correctness issues:
- Skip dead terminals (no consumers) to prevent infinite loop in the inner while(changed) iteration
- Require slices to exactly tile the bounding range to avoid wasteful computation over gaps
Made-with: Cursor
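The "exactly tile the bounding range" requirement can be sketched on plain integer ranges. This is only an illustration, not the code in this PR: `slice_range` and `slices_tile_exactly` are hypothetical names, ranges are assumed half-open along a single axis, and rejecting overlaps as well as gaps is an added assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Half-open [start, end) range along the sliced axis (hypothetical type).
using slice_range = std::pair<std::size_t, std::size_t>;

// True when the slices cover [min start, max end) with no gaps and no overlap.
bool slices_tile_exactly(std::vector<slice_range> slices)
{
    if(slices.empty())
        return false;
    std::sort(slices.begin(), slices.end());
    for(std::size_t i = 0; i < slices.size(); ++i)
    {
        if(slices[i].first >= slices[i].second)
            return false; // empty or inverted slice
        if(i > 0 && slices[i].first != slices[i - 1].second)
            return false; // gap or overlap with the previous slice
    }
    return true;
}
```

Slices `[0,4)`, `[4,8)`, `[8,12)` pass; `[0,4)` and `[6,8)` fail, since hoisting over the bounding range `[0,8)` would compute elements 4 and 5 that no slice reads.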
group_key now includes whether a dot's output feeds into a SiLU pattern (optionally through a bias add). This prevents dots with SiLU from being batched with dots without SiLU, which previously caused detect_downstream_silu to fail for the entire mixed group. With this change, SiLU dots form their own batch (enabling full bias + SiLU absorption) while non-SiLU dots are grouped separately. Made-with: Cursor
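A minimal sketch of why the extra flag in the key separates the groups, using a toy stand-in for a dot instruction. The struct fields and the exact key contents are assumptions for illustration; only the idea of adding a "feeds SiLU" flag to the key comes from the commit message.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <tuple>
#include <vector>

// Toy stand-in for a dot instruction; all fields are hypothetical.
struct toy_dot
{
    std::size_t input_id; // shared input the dot reads from
    std::size_t k;        // reduction dimension
    std::size_t n;        // output columns
    bool feeds_silu;      // output reaches sigmoid + mul, possibly through a bias add
};

using group_key = std::tuple<std::size_t, std::size_t, std::size_t, bool>;

// Bucket dot indices by key so SiLU and non-SiLU dots never share a batch.
std::map<group_key, std::vector<std::size_t>> group_dots(const std::vector<toy_dot>& dots)
{
    std::map<group_key, std::vector<std::size_t>> groups;
    for(std::size_t i = 0; i < dots.size(); ++i)
    {
        const toy_dot& d = dots[i];
        groups[group_key{d.input_id, d.k, d.n, d.feeds_silu}].push_back(i);
    }
    return groups;
}
```

Three dots with identical shapes, two of which feed SiLU, end up in two groups instead of one mixed group.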
…ting
slice_squeeze_pw_silu_chain: SiLU is now hoisted above slices, so the expected output shows add → sigmoid → mul → slice → squeeze instead of per-slice SiLU.
hoist_silu_above_slices_with_unsqueeze_concat: After hoisting, the slice results are non-contiguous views, so concat(unsqueeze(slice)) cannot be simplified to a plain reshape. Updated m2 to match the actual output.
Made-with: Cursor
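Hoisting is value-preserving because an elementwise op commutes with slicing. The sketch below checks that on a 1-D vector; `silu_1d` and `slice_1d` are hypothetical helpers written for this illustration, not MIGraphX APIs.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU(x) = x * sigmoid(x), applied elementwise.
std::vector<double> silu_1d(const std::vector<double>& x)
{
    std::vector<double> y(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = x[i] / (1.0 + std::exp(-x[i]));
    return y;
}

// Copy of elements [start, end) of a 1-D tensor.
std::vector<double> slice_1d(const std::vector<double>& x, std::size_t start, std::size_t end)
{
    assert(start <= end && end <= x.size());
    return std::vector<double>(x.begin() + static_cast<std::ptrdiff_t>(start),
                               x.begin() + static_cast<std::ptrdiff_t>(end));
}
```

For any input, `slice_1d(silu_1d(x), a, b)` equals `silu_1d(slice_1d(x, a, b))`, so one SiLU over the whole tensor can replace one SiLU per slice.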
    return {sig, current};
}
static void hoist_pointwise_above_slices(module& m)
// outputs that converge to a single pointwise merge (e.g. SiLU:
// slice → sigmoid → mul(slice, sigmoid)).
static std::pair<pw_chain_sig, instruction_ref>
trace_pw_chain(instruction_ref start)
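The diamond shape described in that comment can be sketched on a toy graph. This is a simplified illustration, not the body of trace_pw_chain: `toy_ins`, `no_merge` and `find_diamond_merge` are hypothetical names, and the graph is just a vector of nodes with consumer indices.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy instruction: op name plus indices of the instructions that consume it.
struct toy_ins
{
    std::string op;
    std::vector<std::size_t> outputs;
};

// Sentinel returned when no diamond is found (hypothetical encoding).
constexpr std::size_t no_merge = static_cast<std::size_t>(-1);

// If `start` has exactly two consumers and one of them feeds only the other,
// return the index of the merge instruction; otherwise return no_merge.
std::size_t find_diamond_merge(const std::vector<toy_ins>& g, std::size_t start)
{
    const std::vector<std::size_t>& outs = g.at(start).outputs;
    if(outs.size() != 2)
        return no_merge;
    for(std::size_t side = 0; side < 2; ++side)
    {
        const std::size_t branch = outs[side];
        const std::size_t merge  = outs[1 - side];
        const std::vector<std::size_t>& branch_outs = g.at(branch).outputs;
        if(branch_outs.size() == 1 && branch_outs.front() == merge)
            return merge;
    }
    return no_merge;
}
```

For the SiLU case, slice (node 0) feeds sigmoid (node 1) and mul (node 2), and sigmoid feeds only mul, so the merge is node 2.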
concat_unsqueeze - #4984
Adds a matcher for the concat_reshape pattern. This is a partial piece split out of #4723.
Changes after further iteration/prototyping to ensure fusion doesn't break other horizontal passes, seeing a 4.8% improvement in lower batch model run performance
Motivation
Customer model performance improvements as part of AIMIGRAPHX-885
Found an opportunity to fuse 4 identical MLP towers into batched GEMM chains. This uses horizontal fusion so that we can perform the 1-to-N operation and combine the prediction towers into a smaller set of fused ops.
The change here performs a chain-aware fusion so that we can identify which fusion chains we can compile while also preserving other dot-based fusions (vertical fusions on dot, for example) by using layer depth > 2 as the filter for the chains we want to optimize.
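The depth filter can be sketched as a walk along a chain of dots. This is an assumed toy encoding, not the matcher in this PR: `next_dot`, `no_next`, `chain_depth` and `is_tower_candidate` are hypothetical names, and each dot is assumed to feed at most one further dot.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sentinel meaning "no further dot in this chain" (hypothetical encoding).
constexpr std::size_t no_next = static_cast<std::size_t>(-1);

// next_dot[i] is the index of the dot consuming dot i's (activated) output.
// Returns the number of dots reachable from `head`, including `head` itself.
std::size_t chain_depth(const std::vector<std::size_t>& next_dot, std::size_t head)
{
    std::size_t depth = 0;
    std::size_t cur   = head;
    // The depth bound guards against a malformed cyclic chain.
    while(cur != no_next && depth < next_dot.size())
    {
        ++depth;
        cur = next_dot.at(cur);
    }
    return depth;
}

// Only chains deeper than 2 layers are treated as tower candidates.
bool is_tower_candidate(const std::vector<std::size_t>& next_dot, std::size_t head)
{
    return chain_depth(next_dot, head) > 2;
}
```

A 4-layer chain passes the filter, while a 2-dot chain is left alone for the other dot-based fusions.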
The goal here is to reduce the number of dots and fuse them with the sigmoid and mul ops into one kernel, avoiding a bunch of operations on smaller tensors by batching them into a larger op, which reduces the launch overhead that scales with the number of kernels.
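As a CPU-side illustration of the batching idea, the sketch below compares per-tower GEMMs with one batched GEMM over towers stacked on a leading batch axis. The row-major layout and the function names are assumptions for this example; the real fusion operates on MIGraphX instructions, not raw buffers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major GEMM for one tower: c[M x N] = a[M x K] * w[K x N].
std::vector<double> gemm(const std::vector<double>& a, const std::vector<double>& w,
                         std::size_t m, std::size_t k, std::size_t n)
{
    assert(a.size() == m * k && w.size() == k * n);
    std::vector<double> c(m * n, 0.0);
    for(std::size_t i = 0; i < m; ++i)
        for(std::size_t p = 0; p < k; ++p)
            for(std::size_t j = 0; j < n; ++j)
                c[i * n + j] += a[i * k + p] * w[p * n + j];
    return c;
}

// One batched call over towers stacked on a leading batch axis:
// a is [B x M x K], w is [B x K x N], the result is [B x M x N].
std::vector<double> batched_gemm(const std::vector<double>& a, const std::vector<double>& w,
                                 std::size_t batch, std::size_t m, std::size_t k, std::size_t n)
{
    assert(a.size() == batch * m * k && w.size() == batch * k * n);
    std::vector<double> c(batch * m * n, 0.0);
    for(std::size_t b = 0; b < batch; ++b)
        for(std::size_t i = 0; i < m; ++i)
            for(std::size_t p = 0; p < k; ++p)
                for(std::size_t j = 0; j < n; ++j)
                    c[(b * m + i) * n + j] += a[(b * m + i) * k + p] * w[(b * k + p) * n + j];
    return c;
}
```

The batched result is just the per-tower results laid out back to back, so N small launches can be replaced by one larger one without changing values.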
Technical Details
Creates a fusion block to find, track (via hashing) and trace chains that capture a tower of MLP prediction layers.
Compared to baseline develop I was seeing a 5% improvement on batch 1, reducing 24 individual dots and combining everything in the MLP chains (Sigmoid, Mul, Dot, Add) into a mlir_dot_add_sigmoid_mul kernel. The result is that we take 24 operations down to 3 SiLU + 1 dot for these operations, as it reduces 6 towers and 4 layers.
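For reference, here is a scalar sketch of what one fused layer computes on a single input row, assuming the op order implied by the names above: z = x·W + bias, then y = z * sigmoid(z). The function name and the row-major [K x N] weight layout are assumptions for illustration, not the generated MLIR kernel.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference for one dot + add + sigmoid + mul layer on a single input row.
// x has K elements, w is row-major [K x N], bias has N elements.
std::vector<double> dot_add_sigmoid_mul_ref(const std::vector<double>& x,
                                            const std::vector<double>& w,
                                            const std::vector<double>& bias)
{
    const std::size_t k = x.size();
    const std::size_t n = bias.size();
    assert(w.size() == k * n);
    std::vector<double> y(n);
    for(std::size_t j = 0; j < n; ++j)
    {
        double z = bias[j]; // add
        for(std::size_t p = 0; p < k; ++p)
            z += x[p] * w[p * n + j]; // dot
        const double s = 1.0 / (1.0 + std::exp(-z)); // sigmoid
        y[j] = z * s; // mul
    }
    return y;
}
```

Doing all four steps in one pass avoids writing the intermediate dot and add results out to memory between separate kernels.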
Changelog Category
Add a CHANGELOG.md entry for any option other than Not Applicable