Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
391 changes: 391 additions & 0 deletions docs/arch/fusion.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,391 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

.. http://www.apache.org/licenses/LICENSE-2.0

.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.

.. _fusion-arch:

Operator Fusion
===============

Operator fusion is one of the most impactful optimizations in TVM. Instead of launching one kernel
per operator (e.g., conv2d, bias_add, relu), fusion merges multiple operators into a single kernel,
eliminating intermediate memory allocations and kernel launch overhead.

TVM provides two complementary fusion mechanisms:

- **Automatic fusion** (``FuseOps`` + ``FuseTIR``): groups operators based on their computational
patterns using a post-dominator analysis algorithm.
- **Pattern-based fusion** (``FuseOpsByPattern``): groups operators that match user-defined
dataflow patterns, typically for offloading to external backends (cuBLAS, CUTLASS, DNNL, etc.).

Both produce the same output: Relax functions marked with ``Primitive=True`` that are later
lowered to fused TIR kernels or dispatched to external libraries.

Overview
--------

Fusion involves three passes:

.. code-block:: text

IRModule (after LegalizeOps)
▼ AnnotateTIROpPattern ← label each op (elementwise, reduce, etc.)
IRModule (annotated)
▼ FuseOps ← group ops into fused Relax functions
IRModule (with fused functions marked Primitive=True)
▼ FuseTIR ← merge TIR PrimFuncs inside each group
IRModule (fused TIR kernels)

In the compilation pipeline, these passes appear in the backend-specific ``legalize_passes``
phase. For example, the CUDA pipeline (``python/tvm/relax/backend/cuda/pipeline.py``) runs:

.. code-block:: python

LegalizeOps() # lower Relax ops to call_tir
AnnotateTIROpPattern() # annotate pattern kinds
FoldConstant()
FuseOps() # group ops
FuseTIR() # merge TIR functions


Operator Pattern Classification
-------------------------------

Before fusion, ``AnnotateTIROpPattern`` analyzes each TIR function in the module and assigns
an ``OpPatternKind``. The fusion algorithm uses these pattern kinds to decide which operators
can be fused together.

.. list-table::
:header-rows: 1
:widths: 20 10 70

* - Pattern Kind
- Value
- Description
* - ``kElemWise``
- 0
- Elementwise: one-to-one input/output mapping (e.g., ``add``, ``relu``, ``exp``).
* - ``kBroadcast``
- 1
- Broadcasting: output axes map to input axes in order, but some input axes may be
broadcast (e.g., ``bias_add``). Note: ``transpose`` is **not** broadcast because axes
are reordered.
* - ``kInjective``
- 2
- Injective: each output element depends on a single input element, but the mapping may
be non-trivial (e.g., ``reshape``, ``concatenate``, ``transpose``).
* - ``kCommReduce``
- 3
- Communicative reduction: output elements aggregate over input elements
(e.g., ``sum``, ``max``, ``mean``).
* - ``kOutEWiseFusable``
- 4
- Complex operation whose output can accept elementwise followers, but cannot chain
with another complex op (e.g., ``conv2d``, ``matmul``, ``dense``).
* - ``kTuple``
- 7
- Tuple node. Can fuse into subsequent injective ops but is treated specially.
* - ``kOpaque``
- 8
- Opaque: cannot be fused (e.g., external function calls, operations with side effects).

These kinds form an ordering: lower values are "simpler" and more fusable. The fusion algorithm
uses ``CombinePattern(lhs, rhs) = max(lhs, rhs)`` when merging patterns along a path.


FuseOps: Automatic Fusion
-------------------------

``FuseOps`` (``src/relax/transform/fuse_ops.cc``) groups bindings in a dataflow block into
new Relax functions. It operates only within ``DataflowBlock``\ s — if your module doesn't have
any, run ``ConvertToDataflow`` first.

Algorithm
~~~~~~~~~

The fusion algorithm addresses diamond-shaped dataflow branches, where a single producer
(e.g., conv2d) has multiple consumers that eventually reconverge:

.. code-block:: text

conv2d
/ | \
/ | \
op op op
\ | /
\ | /
elemwise add

At the point of ``conv2d``, we don't know if all future paths will merge. The algorithm uses
**post-dominator analysis** to resolve this:

1. **Build forward graph**: construct an ``IndexedForwardGraph`` from the dataflow block.
Each node has an ``OpPatternKind`` and a list of forward edges.

2. **Build post-dominator tree**: compute the immediate post-dominator of each node using
Least Common Ancestor (LCA) on the DAG. The post-dominator of a node is the closest
downstream node where **all** future paths converge.

3. **Fuse groups**: for each node in topological order, check if it can be fused with its
immediate post-dominator:

- **CheckPath**: verify that all paths from the node to its post-dominator satisfy the
fusion conditions (pattern compatibility, depth limits, argument limits).
- **CommitFuse**: mark all intermediate nodes as belonging to the same group using a
Union-Find data structure.

4. **Create grouped functions**: extract each group into a new ``relax.Function`` with the
attribute ``Primitive=True``. Replace the original bindings with a call to the grouped
function.

Fusion rules
~~~~~~~~~~~~

The key fusion decisions depend on the ``OpPatternKind`` of the source, the path, and the
post-dominator. The algorithm runs in three phases (via ``GraphPartitioner::RunFuse``) so that
higher-complexity ops get a chance to fuse first:

- **Phase 0**: ``kOutEWiseFusable`` ops (e.g., ``conv2d``) can fuse with their elementwise
post-dominator if all intermediate ops are broadcast or simpler. This enables patterns like
conv2d + bias_add + relu. Two ``kOutEWiseFusable`` ops cannot fuse together.
- **Phase 1**: ``kInjective`` and ``kTuple`` ops can fuse only when all paths to the
post-dominator are injective or simpler. This is deferred to phase 1 so that
``kOutEWiseFusable`` groups are finalized first.
- **Phase 2**: fuse injective ops into intermediate tuple nodes that have already been absorbed
by subsequent injective groups.

``kElemWise`` / ``kBroadcast`` ops are processed in **every** phase (not restricted to one):
they can fuse into a post-dominator that is injective or reduction. The sink (final node) may
also be a ``kOutEWiseFusable`` group that was formed in phase 0 — this is how elementwise
producers merge into an existing conv2d fusion group.

Additional constraints:

- **Reduction** (``kCommReduce``) ops never initiate fusion — they act as sinks only. Elementwise
and broadcast producers can fuse *into* a reduction, but a reduction cannot fuse forward.
- **Opaque** ops are fusion barriers.
- A group cannot exceed ``kMaxFusedOps`` (256) nodes or the maximum function argument count.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It would be beneficial to mention that the kMaxFusedOps limit is configurable via the relax.FuseOps.max_depth pass configuration option, as this provides users with control over the fusion granularity.


Example
~~~~~~~

Given two elementwise ops (``add``, ``exp``) and one injective op (``squeeze``).
The examples below are simplified pseudocode — real TVMScript would reference TIR functions
via ``cls.func_name``:

.. code-block:: python

# Before FuseOps (simplified)
@R.function
def main(x: R.Tensor((10, 20), "float32")):
with R.dataflow():
lv0 = R.call_tir(add, (x, const_1), out_sinfo=R.Tensor((10, 20), "float32"))
lv1 = R.call_tir(exp, (lv0,), out_sinfo=R.Tensor((10, 20), "float32"))
gv = R.call_tir(squeeze, (lv1,), out_sinfo=R.Tensor((10, 20), "float32"))
R.output(gv)
return gv

After ``FuseOps``, all three are grouped into a single function:

.. code-block:: python

# After FuseOps
@R.function(private=True)
def fused_add_exp_squeeze(x, p0):
R.func_attr({"Primitive": True})
with R.dataflow():
lv0 = R.call_tir(add, (x, p0), ...)
lv1 = R.call_tir(exp, (lv0,), ...)
gv = R.call_tir(squeeze, (lv1,), ...)
R.output(gv)
return gv

@R.function
def main(x: R.Tensor((10, 20), "float32")):
with R.dataflow():
gv = fused_add_exp_squeeze(x, const_1)
R.output(gv)
return gv


FuseTIR: Merging TIR Functions
------------------------------

``FuseTIR`` (``src/relax/transform/fuse_tir.cc``) takes the grouped Relax functions produced by
``FuseOps`` and merges their internal TIR ``PrimFunc``\ s into a single TIR function.

Before ``FuseTIR``, a fused group still contains multiple ``R.call_tir`` calls to separate
TIR functions. ``FuseTIR`` inlines and merges them:

.. code-block:: text

Before FuseTIR:
fused_add_exp_squeeze:
call_tir(add, ...) → separate TIR PrimFunc
call_tir(exp, ...) → separate TIR PrimFunc
call_tir(squeeze, ...) → separate TIR PrimFunc

After FuseTIR:
fused_add_exp_squeeze: → single merged TIR PrimFunc

The merged function eliminates intermediate buffers — the output of ``add`` is directly consumed
by ``exp`` without writing to and reading from global memory. This is the core performance benefit
of fusion.

Internally, ``FuseTIR`` uses a ``SymbolicMatcher`` to align symbolic shape variables across the
TIR functions being merged, ensuring that dimensions are correctly mapped when combining buffer
accesses.


FuseOpsByPattern: Pattern-Based Fusion
--------------------------------------

While ``FuseOps`` makes fusion decisions automatically based on operator patterns,
``FuseOpsByPattern`` lets you specify exactly which operator combinations to fuse using
the Relax :ref:`Dataflow Pattern Language (DPL) <relax-dpl>`.

This is primarily used for **backend-specific dispatch**: identifying operator subgraphs that
should be offloaded to external libraries like cuBLAS, CUTLASS, cuDNN, or DNNL.

FusionPattern
~~~~~~~~~~~~~

A ``FusionPattern`` (``python/tvm/relax/transform/transform.py``) defines what to match:

.. code-block:: python

from tvm.relax.dpl import wildcard, is_op
from tvm.relax.transform import FusionPattern

# Match: matmul(x, w) + bias
x = wildcard()
w = wildcard()
bias = wildcard()
matmul = is_op("relax.matmul")(x, w)
out = is_op("relax.add")(matmul, bias)

pattern = FusionPattern(
name="cutlass.matmul_bias",
pattern=out,
annotation_patterns={"matmul": matmul, "bias": bias},
check=my_check_function, # optional validation
)

Fields:

- ``name``: pattern identifier, typically prefixed with the backend name (e.g.,
``"cutlass.matmul_bias"``).
- ``pattern``: a DFPattern describing the subgraph to match. See the
:ref:`DPL deep dive <relax-dpl>` for the full pattern language.
- ``annotation_patterns``: a mapping of names to sub-patterns within the main pattern. These
are extracted during matching and made available to the ``check`` function and
``attrs_getter``.
- ``check``: an optional ``Callable[[PatternCheckContext], bool]`` that validates whether
a match should be accepted. Receives the matched expression, annotated sub-expressions,
variable usages, and binding information.
- ``attrs_getter``: an optional function that extracts attributes (e.g., transpose flags,
data types) from the matched expressions to annotate the grouped function.

Applying patterns
~~~~~~~~~~~~~~~~~

.. code-block:: python

from tvm.relax.transform import FuseOpsByPattern

mod = FuseOpsByPattern(
patterns=[pattern1, pattern2, ...], # ordered by priority
bind_constants=True,
annotate_codegen=False,
)(mod)

Key parameters:

- ``patterns``: a list of ``FusionPattern`` objects, ordered by priority. Higher-priority
patterns come first — if a subgraph matches multiple patterns, the first match wins.
- ``bind_constants``: if ``True``, constants used by the matched subgraph are captured inside
the grouped function.
- ``annotate_codegen``: if ``True``, wraps each composite function with an outer function
annotated with ``"Codegen"`` and ``"global_symbol"`` attributes for external backend dispatch.
The ``"Codegen"`` value is derived from the pattern name prefix (e.g., ``"dnnl"`` from
``"dnnl.conv2d_relu"``).

PatternCheckContext
~~~~~~~~~~~~~~~~~~~

The ``check`` function receives a ``PatternCheckContext`` with:

- ``matched_expr``: the root expression matched by the pattern.
- ``annotated_expr``: a mapping from annotation pattern names to their matched expressions.
- ``matched_bindings``: variable-to-value bindings within the matched subgraph.
- ``var_usages``: a mapping from variable definitions to all their uses in the function.
- ``value_to_bound_var``: reverse mapping from values to the variables they are bound to.

This context enables sophisticated validation logic, such as checking that an intermediate
result is not used outside the fused group, or verifying data type compatibility.


How Backends Use Fusion
-----------------------

The default backend pipelines (CUDA, ROCm, CPU, etc.) all include ``FuseOps`` + ``FuseTIR``
in their ``legalize_passes`` phase for automatic fusion. For example, the CUDA pipeline
(``python/tvm/relax/backend/cuda/pipeline.py``) runs::

LegalizeOps → AnnotateTIROpPattern → FoldConstant → FuseOps → FuseTIR → DLight

For external library dispatch (cuBLAS, CUTLASS, cuDNN, DNNL), ``FuseOpsByPattern`` is used
separately. These are **not** included in the default pipeline — users add them explicitly
when building a custom compilation flow. The typical sequence is:

1. **Pattern-based dispatch** (``FuseOpsByPattern``): identify subgraphs that should be
offloaded to external libraries. For example, CUTLASS patterns match
matmul+bias+activation combinations (``python/tvm/relax/backend/cuda/cutlass.py``).
Functions marked by patterns are annotated with ``Composite`` and optionally ``Codegen``
attributes.

2. **Automatic fusion** (``FuseOps`` + ``FuseTIR``): remaining operators that were not
matched by backend patterns are fused automatically based on their pattern kinds.


Source Code Map
---------------

.. list-table::
:header-rows: 1
:widths: 50 50

* - Path
- Contents
* - ``src/relax/transform/fuse_ops.cc``
- FuseOps and FuseOpsByPattern implementation
* - ``src/relax/analysis/graph_partitioner.h``
- IndexedForwardGraph, DominatorTree, GraphPartitioner (Union-Find)
* - ``src/relax/transform/fuse_tir.cc``
- FuseTIR implementation, SymbolicMatcher
* - ``include/tvm/relax/op_attr_types.h``
- ``OpPatternKind`` enum definition
* - ``python/tvm/relax/transform/transform.py``
- Python API: FuseOps, FuseTIR, FuseOpsByPattern, FusionPattern
* - ``python/tvm/relax/dpl/``
- Dataflow Pattern Language (DFPattern, is_op, wildcard, etc.)
* - ``python/tvm/relax/backend/cuda/cutlass.py``
- Example: CUTLASS fusion patterns
* - ``python/tvm/relax/backend/cuda/cublas.py``
- Example: cuBLAS fusion patterns
Loading
Loading