Hoist and horizontal dot#5014
Conversation
Hoist a pointwise (or SiLU diamond) chain that is replicated across several sibling slices of the same instruction above the slices: the chain is computed once on the bounding slice and each consumer reads its sub-range back out. This exposes a single wide pointwise op that downstream passes (find_splits, horizontal fusion) can recombine instead of N identical narrow ops. trace_pw_chain handles both linear pointwise chains and the SiLU diamond (x -> sigmoid -> mul(x, sigmoid)). Slices are only hoisted when they exactly tile the bounding range so no redundant elements are computed.
Cover the hoist transform in isolation (simplify_reshapes + DCE only): - relu hoisted above two sibling unit slices - SiLU diamond (sigmoid/mul) hoisted above sibling slices - no hoist when slices leave a gap in the bounding range - no hoist when sibling slices feed different pointwise chains
Batch structurally-identical dot operations (same activation/weight lens and output type, constant weights) into a single batched GEMM by stacking activations and weights along a new leading axis, then slice and squeeze the per-dot results back out. Per review on #4723, only the dots are fused here; downstream elementwise work (bias add, SiLU, ...) is left in place so the existing fusions (find_splits in simplify_algebra and the pointwise/MLIR fusions) can recombine it on top of the batched dot rather than duplicating that logic.
Cover the batched-dot fusion (fuse_horizontal + DCE): - two shape-identical dots with constant weights batch into one GEMM - dots with non-constant weights are not fused - dots with mismatched activation/weight shapes do not group
Exercise simplify_reshapes (hoist_pointwise_above_slices) and fuse_horizontal (dot_horizontal_fusion) in a single pipeline on a module containing both opportunities: a SiLU replicated across sibling slices and a set of parallel constant-weight dots. Verify the dots collapse to a single batched GEMM and the per-slice SiLUs collapse to one sigmoid/mul.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #5014 +/- ##
===========================================
+ Coverage 92.73% 92.74% +0.01%
===========================================
Files 596 596
Lines 31813 31896 +83
===========================================
+ Hits 29500 29581 +81
- Misses 2313 2315 +2
🚀 New features to boost your workflow:
|
| } | ||
|
|
||
| EXPECT(m1.sort() == m2.sort()); | ||
| } |
There was a problem hiding this comment.
This should already be handled by find_splits.
There was a problem hiding this comment.
Understood. Modified the changes so we can reuse as much of find_splits as possible, got rid of the netire tracing section so its just a case where we handle the smaller cases. Diamon pattern of ops to lift before slice should be handled like we talked about after fuse_pointwise.
Let me know what you think I'm just running some additional tests on this.
Regressions detected 🔴 * No develop baseline was found for this PR's branch point; compared against the latest available develop run instead. |
|
Replace the bespoke pointwise-chain hoisting engine (chain tracing, equivalence, replay) in find_splits with a unary-only hoist_bounded_splits that reuses the existing get_split_groups/is_fusable/split_groups_are_dependent helpers. When sibling slices only partially tile their tensor, a single-input op replicated across them is computed once on the bounding slice and each consumer re-slices its sub-range. Multi-op linear chains are peeled one op per fixpoint iteration, and SiLU/diamond patterns are handled once fuse_pointwise collapses them into a single pointwise op, so no diamond-specific code is needed here. Co-authored-by: Cursor <cursoragent@cursor.com>
The slim find_splits hoist only widens single-input ops, so a raw sigmoid/mul SiLU diamond is no longer hoisted directly. Update the two SiLU tests to run fuse_pointwise first (collapsing each per-slice SiLU into one pointwise op) and assert on the resulting fused pointwise: hoist_silu_above_slices now compares against a hoisted pointwise program, and the end-to-end test asserts a single pointwise plus a single batched dot. Co-authored-by: Cursor <cursoragent@cursor.com>
ensures we're not grabbing everything ith bias which was generated errnornous add_kernels by fusing dot early.
| mm->add_return({sub0, sub1}); | ||
| } | ||
| EXPECT(p1 == p2); | ||
| } |
There was a problem hiding this comment.
This test case should go into a different test suite like fuse_pointwise or fuse_pointwise_reduce.
| } | ||
|
|
||
| EXPECT(m1.sort() == m2.sort()); | ||
| } |
There was a problem hiding this comment.
It seems like we handle this already except we currently dont insert the initial slice.
| // op per fixpoint iteration (after the first hoist the sub-slices fully tile | ||
| // the wide result, so get_splits then applies), and a fused pointwise op | ||
| // (e.g. a SiLU produced by fuse_pointwise) already appears as a single op. | ||
| // ------------------------------------------------------------------------ |
There was a problem hiding this comment.
This comment should rewritten to make what it is doing clearer. As i understand it from the test case, we want slice{0,1} -> pw, slice{1,2} -> pw to be written as slice{0,2} -> pw -> (slice{0,1}, slice{1,2}).
| } | ||
| } | ||
| return changed; | ||
| } |
There was a problem hiding this comment.
Remove this function this already duplicate a lot of what is done in the main apply function.
| // The sibling slices only partially tile their tensor, so the | ||
| // full-cover split fusion does not apply; hoist a replicated | ||
| // single-input op above their bounding slice instead. | ||
| hoist_bounded_splits(m, ins); |
There was a problem hiding this comment.
We shouldn't go through a seperate path. We can add a partial flag to get_splits(default to false because this function is used elsewhere) where we skip checking that the slices start with 0 or ends is equal to the dimension(I couldn't find that check so I am not sure where this is happening).
Then before we insert the op we can check if the slices are partial then insert a slice and then the op. There is no need to reimplement another version of this function.
There was a problem hiding this comment.
Sure that can be done. Hold on
- move tests - Adjust changes to find_splits with partial flag.
Motivation
Final piece split from #4723
Adds in a hoist pass into simplify reshapes to move point wise operations above slices that is repeated across several slices. This sets us up to perofrm a horizontal dot fusion across each of the branches.
Technical Details
Changelog Category
Add a
CHANGELOG.mdentry for any option other thanNot Applicable