Commit e22be4e

[Docs] Add BYOC external library dispatch architecture documentation (#19395)
Add an architecture document (`docs/arch/external_library_dispatch.rst`) for TVM's BYOC (Bring Your Own Codegen) mechanism, covering the end-to-end pipeline for offloading operator subgraphs to external libraries (cuBLAS, CUTLASS, cuDNN, DNNL)
1 parent e5d4c55 commit e22be4e

docs/arch/external_library_dispatch.rst

..  Licensed to the Apache Software Foundation (ASF) under one
    or more contributor license agreements.  See the NOTICE file
    distributed with this work for additional information
    regarding copyright ownership.  The ASF licenses this file
    to you under the Apache License, Version 2.0 (the
    "License"); you may not use this file except in compliance
    with the License.  You may obtain a copy of the License at

..    http://www.apache.org/licenses/LICENSE-2.0

..  Unless required by applicable law or agreed to in writing,
    software distributed under the License is distributed on an
    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    KIND, either express or implied.  See the License for the
    specific language governing permissions and limitations
    under the License.

.. _external-library-dispatch:

External Library Dispatch (BYOC)
================================

When deploying models, certain operator patterns (e.g., matmul + bias + relu) can be executed
more efficiently by vendor-optimized libraries such as cuBLAS, CUTLASS, cuDNN, or DNNL. TVM's
**BYOC (Bring Your Own Codegen)** mechanism identifies these patterns in a Relax module and
offloads them to external backends, while keeping the rest of the computation on TVM's own
generated kernels.

This document explains the BYOC pipeline: how patterns are registered, how subgraphs are
matched and extracted, how backend code generators are invoked, and how the externally compiled
code is executed at runtime.

Overview
--------

The BYOC pipeline consists of four stages:

.. code-block:: text

   IRModule (high-level Relax IR)
       │
       ▼ FuseOpsByPattern        ← match high-level ops, create composite functions
   IRModule (with Composite + Codegen attributes)
       │
       ▼ RunCodegen              ← invoke backend codegen via FFI
   IRModule (with call_dps_packed to ExternFunc)
     + external runtime Modules
       │
       ▼ LegalizeOps + FuseOps + ...   ← compile remaining ops normally
       │
       ▼ VM compilation          ← link external modules into executable
   Deployable artifact

Each stage is a Relax transformation pass that operates on the ``IRModule``:

1. **FuseOpsByPattern** — matches operator subgraphs against registered patterns and groups them
   into composite functions annotated with ``Composite`` and ``Codegen`` attributes.
2. **MergeCompositeFunctions** (optional) — merges multiple composite functions targeting the same
   backend when inter-operator dependencies allow.
3. **RunCodegen** — finds all functions with a ``Codegen`` attribute, invokes the corresponding
   backend code generator via FFI, and replaces the original calls with ``call_dps_packed``
   to externally compiled functions.
4. **Linking** — the resulting external ``runtime.Module``\ s are attached to the ``IRModule``
   as the ``external_mods`` attribute and bundled into the final executable during
   ``relax.build()``.

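The staged flow above can be modeled as a chain of module-to-module transformations. The sketch below is a deliberately simplified stand-in (plain dicts instead of an ``IRModule``, and all function names invented for illustration) to show how each pass rewrites the module and hands it to the next:

```python
# Toy model of the BYOC pipeline: each "pass" maps a module (a plain dict)
# to a new module. All names here are illustrative, not TVM APIs.

def fuse_ops_by_pattern(mod):
    # Tag a matched subgraph as a composite function for the "cublas" backend.
    mod = dict(mod)
    mod["fused_matmul"] = {"attrs": {"Codegen": "cublas",
                                     "Composite": "cublas.matmul"},
                           "body": ["matmul"]}
    return mod

def run_codegen(mod):
    # Replace tagged functions with references to externally compiled modules.
    mod = dict(mod)
    external_mods = []
    for name, fn in list(mod.items()):
        if isinstance(fn, dict) and fn.get("attrs", {}).get("Codegen"):
            external_mods.append((fn["attrs"]["Codegen"], name))
            mod[name] = {"call_dps_packed": name}  # stand-in for an ExternFunc call
    mod["external_mods"] = external_mods
    return mod

def build(mod):
    # "Link" the external modules into the final artifact.
    return {"executable": True, "linked": mod["external_mods"]}

module = {"main": {"body": ["matmul", "relu"]}}
artifact = build(run_codegen(fuse_ops_by_pattern(module)))
print(artifact)  # {'executable': True, 'linked': [('cublas', 'fused_matmul')]}
```

The real passes operate on Relax IR rather than dicts, but the shape is the same: each stage consumes and produces a whole module, so stages compose freely.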
Pattern Registration
--------------------

Each backend registers the operator patterns it supports in a **global pattern registry**
(``python/tvm/relax/backend/pattern_registry.py``). The registry is a static table that maps
pattern names to ``FusionPattern`` objects.

Registering patterns
~~~~~~~~~~~~~~~~~~~~

.. code-block:: python

   from tvm.relax.backend.pattern_registry import register_patterns
   from tvm.relax.backend.patterns import make_matmul_pattern

   register_patterns([
       (
           "cublas.matmul",              # pattern name (prefix = backend)
           *make_matmul_pattern(         # returns (DFPattern, annotation_patterns)
               with_bias=False,
           ),
           _check_matmul,                # check function
       ),
       (
           "cublas.matmul_bias_relu",
           *make_matmul_pattern(
               with_bias=True,
               activation="relax.nn.relu",
           ),
           _check_matmul,
       ),
       # ... more patterns
   ])

Each entry is a tuple of ``(name, pattern, annotation_patterns, check_func)`` that gets
converted to a ``FusionPattern`` object. The name prefix (e.g., ``"cublas"``) identifies the
backend; ``get_patterns_with_prefix("cublas")`` retrieves all patterns for that backend.

Patterns registered later have **higher priority** — when a subgraph matches multiple patterns,
the highest-priority match wins.

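As an illustration of the prefix lookup and the priority rule, here is a toy stand-in for the registry (the real implementation lives behind ``register_patterns`` in C++; the function names below only mirror its shape):

```python
# Simplified stand-in for the pattern registry: later registrations win.
# Names mirror the real API shape, but this is not TVM code.

_registry = []  # (name, pattern) pairs in registration order

def register_patterns(entries):
    _registry.extend(entries)

def get_patterns_with_prefix(prefix):
    # Retrieve all patterns whose name starts with "<prefix>.".
    return [e for e in _registry if e[0].startswith(prefix + ".")]

def best_match(candidates):
    # When several patterns match a subgraph, the last-registered one wins.
    return candidates[-1][0] if candidates else None

register_patterns([("cublas.matmul", "matmul"),
                   ("cublas.matmul_bias_relu", "matmul+bias+relu")])

cublas_patterns = get_patterns_with_prefix("cublas")
print([name for name, _ in cublas_patterns])
# ['cublas.matmul', 'cublas.matmul_bias_relu']
print(best_match(cublas_patterns))  # cublas.matmul_bias_relu
```

This ordering is why more specific fused patterns (matmul + bias + relu) are registered after the plain ones: a subgraph that matches both is claimed by the larger fusion.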
Pattern templates
~~~~~~~~~~~~~~~~~

``python/tvm/relax/backend/patterns.py`` provides reusable templates for common patterns:

- ``make_matmul_pattern(with_bias, activation, transposed_rhs)`` — matmul with optional bias
  and activation fusion
- ``make_conv2d_pattern(with_bias, activation)`` — 2D convolution
- ``make_attention_pattern()`` — multi-head attention
- ``make_residual_block_pattern()`` — residual connections
- ``make_layer_norm_pattern()`` / ``make_rms_norm_pattern()`` — normalization layers

Each template returns ``(DFPattern, Mapping[str, DFPattern])`` — the main pattern and its
annotation sub-patterns.

Check functions
~~~~~~~~~~~~~~~

The check function validates whether a matched subgraph can actually be handled by the backend.
It receives a ``PatternCheckContext`` and returns ``True`` to accept or ``False`` to reject.

Typical checks include:

- **Data type support**: verify that the operand dtypes are supported (e.g., cuBLAS supports
  float16, float32, int8, bfloat16, and float8 for matmul).
- **Shape constraints**: verify that reduction axes are constant and batch dimensions are
  compatible.
- **Leaking intermediates**: reject the match if an intermediate result is used outside the
  fused group (via ``has_leaking_intermediate_variables()``).

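A check function of this shape might look like the following sketch. Here ``FakeContext`` is a hypothetical stand-in for ``PatternCheckContext`` (which in reality carries the matched expressions and annotated sub-patterns), so only the decision logic is representative:

```python
# Sketch of a backend check function. A plain object with dtype fields
# stands in for the real PatternCheckContext (illustrative only).

SUPPORTED_DTYPES = {"float16", "float32", "int8", "bfloat16"}

class FakeContext:
    """Stand-in for PatternCheckContext (illustrative only)."""
    def __init__(self, lhs_dtype, rhs_dtype, leaks_intermediate=False):
        self.lhs_dtype = lhs_dtype
        self.rhs_dtype = rhs_dtype
        self.leaks_intermediate = leaks_intermediate

def check_matmul(ctx):
    # Data type support: reject dtypes the library cannot handle.
    if ctx.lhs_dtype not in SUPPORTED_DTYPES or ctx.rhs_dtype not in SUPPORTED_DTYPES:
        return False
    # Leaking intermediates: reject if a value escapes the fused group.
    if ctx.leaks_intermediate:
        return False
    return True

print(check_matmul(FakeContext("float16", "float16")))        # True
print(check_matmul(FakeContext("float64", "float16")))        # False
print(check_matmul(FakeContext("float16", "float16", True)))  # False
```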
Partitioning
------------

After patterns are registered, a backend provides a **partition function** that applies
``FuseOpsByPattern`` to an ``IRModule``:

.. code-block:: python

   # python/tvm/relax/backend/cuda/cublas.py
   def partition_for_cublas(mod, bind_constants=False):
       patterns = get_patterns_with_prefix("cublas")
       return transform.FuseOpsByPattern(
           patterns, bind_constants=bind_constants, annotate_codegen=True
       )(mod)

With ``annotate_codegen=True``, each matched subgraph is wrapped in a two-level function
structure:

.. code-block:: text

   # Outer function — tagged for the codegen backend
   @R.function
   def fused_relax_matmul_cublas0(args...):
       R.func_attr({"Codegen": "cublas", "global_symbol": "fused_relax_matmul_cublas0"})
       ...

       # Inner function — identifies the specific pattern
       @R.function(private=True)
       def composite(args...):
           R.func_attr({"Composite": "cublas.matmul_bias_relu"})
           lv0 = R.matmul(x, w)
           lv1 = R.add(lv0, bias)
           lv2 = R.nn.relu(lv1)
           return lv2
       ...

The outer function carries the ``Codegen`` attribute that ``RunCodegen`` uses to dispatch to the
right backend. The inner function carries the ``Composite`` attribute that the backend codegen
uses to identify which operation to emit.

MergeCompositeFunctions
~~~~~~~~~~~~~~~~~~~~~~~

When ``annotate_codegen=False``, ``FuseOpsByPattern`` only creates inner functions with
``Composite`` attributes. A separate ``MergeCompositeFunctions`` pass then groups multiple
composite functions targeting the same backend into a single outer function with ``Codegen``
and ``global_symbol`` attributes.

This is useful when multiple sequential operations should be sent to the same backend as a
single unit (e.g., a sequence of cuBLAS matmuls that share intermediate results). The pass
checks that merging does not create cyclic dependencies between groups.

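The cycle constraint can be illustrated with a toy dependency check (purely illustrative; the real pass analyzes Relax dataflow, not string edges). If two cuBLAS matmuls are separated by a TVM-compiled op, merging them into one group creates a cycle between the merged group and the op in between:

```python
# Toy illustration of the MergeCompositeFunctions cycle constraint.
# Groups may merge only if the grouped dependency graph stays acyclic.

def has_cycle(edges, groups):
    # Collapse node-level edges onto groups, then DFS for a back edge.
    gid = {n: i for i, grp in enumerate(groups) for n in grp}
    adj = {i: set() for i in range(len(groups))}
    for a, b in edges:
        if gid[a] != gid[b]:
            adj[gid[a]].add(gid[b])
    state = {}
    def dfs(u):
        state[u] = 1  # on the current DFS path
        for v in adj[u]:
            if state.get(v) == 1 or (state.get(v) is None and dfs(v)):
                return True
        state[u] = 2  # fully explored
        return False
    return any(state.get(u) is None and dfs(u) for u in adj)

# matmul1 -> tvm_op -> matmul2: merging the two matmuls into one group
# would create a cycle through the TVM-compiled op between them.
edges = [("matmul1", "tvm_op"), ("tvm_op", "matmul2")]
print(has_cycle(edges, [{"matmul1", "matmul2"}, {"tvm_op"}]))    # True
print(has_cycle(edges, [{"matmul1"}, {"tvm_op"}, {"matmul2"}]))  # False
```

A cyclic grouping would have no valid execution order, which is why the pass keeps such matmuls in separate groups.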
Code Generation
---------------

``RunCodegen`` (``src/relax/transform/run_codegen.cc``) is the pass that triggers backend
code generation:

1. Scan the module for all functions with a ``Codegen`` attribute.
2. Group them by backend target name.
3. For each backend, look up the registered codegen function via the FFI key
   ``"relax.ext.<backend>"`` (e.g., ``"relax.ext.cublas"``).
4. Call the codegen function, which returns an array of compiled ``runtime.Module``\ s.
5. Replace the original function calls with ``call_dps_packed(ExternFunc(...), args)``.
6. Attach the compiled modules to the ``IRModule`` as the ``external_mods`` attribute.

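The dispatch logic of steps 1–4 can be mimicked with a plain-Python registry (illustrative only; the real lookup goes through TVM's global FFI function table, and functions are Relax IR rather than attribute dicts):

```python
# Toy model of RunCodegen's FFI dispatch: codegen callables are registered
# under "relax.ext.<backend>" and invoked once per backend group.

_global_funcs = {}

def register_global(name, fn):
    _global_funcs[name] = fn

def run_codegen(functions):
    # Steps 1-2: collect functions carrying a Codegen attribute, by backend.
    groups = {}
    for name, attrs in functions.items():
        backend = attrs.get("Codegen")
        if backend:
            groups.setdefault(backend, []).append(name)
    # Steps 3-4: look up and invoke each backend's codegen function.
    modules = []
    for backend, funcs in groups.items():
        codegen = _global_funcs["relax.ext." + backend]
        modules.extend(codegen(funcs))
    return modules

# A fake "cublas" codegen that returns one stand-in module per function.
register_global("relax.ext.cublas",
                lambda funcs: [("cublas-module", f) for f in funcs])

mods = run_codegen({
    "fused_matmul_cublas0": {"Codegen": "cublas"},
    "plain_tvm_func": {},  # no Codegen attribute: compiled by TVM itself
})
print(mods)  # [('cublas-module', 'fused_matmul_cublas0')]
```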
Codegen registration
~~~~~~~~~~~~~~~~~~~~

Each backend registers a codegen function via TVM's FFI mechanism:

.. code-block:: cpp

   // src/relax/backend/contrib/cublas/codegen.cc
   ffi::Array<ffi::Module> CublasCompiler(
       ffi::Array<Function> functions,
       ffi::Map<ffi::String, ffi::Any> options,
       ffi::Map<Constant, ffi::String> constant_names) {
     ffi::Array<ffi::Module> compiled_functions;
     for (const auto& func : functions) {
       CublasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
       serializer.serialize(func);
       auto graph_json = serializer.GetJSON();
       auto names = serializer.GetConstantNames();
       const auto pf = ffi::Function::GetGlobalRequired("runtime.CublasJSONRuntimeCreate");
       compiled_functions.push_back(
           pf(GetExtSymbol(func), graph_json, names).cast<ffi::Module>());
     }
     return compiled_functions;
   }

   TVM_FFI_STATIC_INIT_BLOCK() {
     namespace refl = tvm::ffi::reflection;
     refl::GlobalDef().def("relax.ext.cublas", CublasCompiler);
   }

The codegen function receives:

- ``functions``: the Relax functions with a ``Codegen`` attribute to compile.
- ``options``: backend-specific compilation options.
- ``constant_names``: a mapping from constant values to their names (for weight handling).

It returns an array of ``runtime.Module`` objects — one per function — that contain the
externally compiled code.

Codegen strategies
~~~~~~~~~~~~~~~~~~

TVM provides two base classes for implementing backend codegens:

- **JSONSerializer** (``src/relax/backend/contrib/codegen_json/codegen_json.h``): serializes the
  composite function into a JSON graph representation. At runtime, a backend-specific JSON
  runtime module interprets the graph and dispatches to library calls. Used by cuBLAS, cuDNN,
  and most backends.

- **CSourceCodegen** (``src/relax/backend/contrib/codegen_c/codegen_c.h``): generates C/CUDA
  source code that is compiled and linked. Used when the backend requires ahead-of-time
  compilation.

Runtime Execution
-----------------

After ``RunCodegen``, the original high-level function calls are replaced with:

.. code-block:: python

   R.call_dps_packed(ExternFunc("fused_relax_matmul_cublas0"), (x, w, bias), ...)

At runtime, ``call_dps_packed`` invokes the externally compiled function through the
``PackedFunc`` interface. The external ``runtime.Module``\ s (produced by the codegen) are
imported into the final executable during ``relax.build()`` and are available via the module's
function lookup mechanism.

For JSON-based backends (cuBLAS, cuDNN), the runtime module deserializes the JSON graph and
dispatches each node to the corresponding library API call. For source-based backends, the
compiled native code is called directly.

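The JSON-runtime idea can be sketched as a tiny interpreter over a serialized graph. Everything below is illustrative: the node layout, op names, and the ``LIBRARY`` dispatch table are invented stand-ins for the real JSON schema and the cuBLAS/cuDNN calls:

```python
import json

# Toy JSON "runtime": deserialize a graph and dispatch each node to a
# library-call table. Real JSON runtimes call cuBLAS/cuDNN kernels instead.

LIBRARY = {
    "matmul": lambda a, b: [[sum(x * y for x, y in zip(row, col))
                             for col in zip(*b)] for row in a],
    "relu": lambda a: [[max(x, 0) for x in row] for row in a],
}

def run_json_graph(graph_json, inputs):
    graph = json.loads(graph_json)
    values = dict(inputs)
    for node in graph["nodes"]:
        args = [values[a] for a in node["inputs"]]
        values[node["name"]] = LIBRARY[node["op"]](*args)
    return values[graph["output"]]

# A serialized matmul + relu graph, as a codegen might emit it.
graph = json.dumps({
    "nodes": [
        {"name": "lv0", "op": "matmul", "inputs": ["x", "w"]},
        {"name": "lv1", "op": "relu", "inputs": ["lv0"]},
    ],
    "output": "lv1",
})
out = run_json_graph(graph, {"x": [[1, -2]], "w": [[1, 0], [0, 1]]})
print(out)  # [[1, 0]]
```

The key property this models: the graph is data, not code, so the same runtime module can execute any graph the serializer produces, choosing the library call per node.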
Adding a New Backend
--------------------

To add support for a new external library:

1. **Define patterns** in ``python/tvm/relax/backend/<target>/``:

   - Create DFPatterns using templates from ``patterns.py`` or custom patterns.
   - Write check functions to validate dtypes, shapes, and other constraints.
   - Register patterns with ``register_patterns()``.
   - Provide a ``partition_for_<backend>(mod)`` convenience function.

2. **Implement codegen** in ``src/relax/backend/contrib/<target>/``:

   - Subclass ``JSONSerializer`` or ``CSourceCodegen``.
   - Implement the visitor that converts composite functions to the target format.
   - Register the codegen function as ``"relax.ext.<target>"``.

3. **Implement runtime** (for JSON-based backends):

   - Create a JSON runtime module that interprets the serialized graph and dispatches
     to the library's API calls.
   - Register the runtime constructor as ``"runtime.<Target>JSONRuntimeCreate"``.

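Tying the three steps together, here is a minimal toy skeleton for a hypothetical ``mylib`` backend. All names are illustrative (``mylib`` does not exist), and a plain dict stands in for TVM's global function table; real backends implement these pieces in Python and C++ as described above:

```python
# Toy skeleton for a hypothetical "mylib" backend: a pattern entry, a
# codegen, and a runtime constructor, wired through a dict "FFI table".

GLOBAL = {}  # stand-in for TVM's global function table

# Step 1: pattern registration (the name prefix identifies the backend).
PATTERNS = [("mylib.matmul", "matmul-pattern", lambda ctx: True)]

# Step 3: runtime constructor, registered as "runtime.<Target>JSONRuntimeCreate".
def mylib_runtime_create(symbol, graph_json):
    return {"symbol": symbol, "graph": graph_json}

GLOBAL["runtime.MylibJSONRuntimeCreate"] = mylib_runtime_create

# Step 2: codegen, registered as "relax.ext.<target>"; it serializes each
# function (here to an empty graph) and hands the result to the runtime.
def mylib_codegen(functions):
    create = GLOBAL["runtime.MylibJSONRuntimeCreate"]
    return [create(name, '{"nodes": []}') for name in functions]

GLOBAL["relax.ext.mylib"] = mylib_codegen

mods = GLOBAL["relax.ext.mylib"](["fused_matmul_mylib0"])
print(mods[0]["symbol"])  # fused_matmul_mylib0
```

Note how the codegen and runtime only meet through the registered constructor name, mirroring how ``RunCodegen`` and the JSON runtime are decoupled via the FFI table.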
Supported Backends
------------------

.. list-table::
   :header-rows: 1
   :widths: 15 25 60

   * - Backend
     - Patterns
     - Operations
   * - cuBLAS
     - ``cublas.*``
     - Matmul (with bias, activation, transpose, dequantize variants)
   * - CUTLASS
     - ``cutlass.*``
     - Matmul, conv2d, attention, residual blocks, decode matmul
   * - cuDNN
     - ``cudnn.*``
     - Conv2d (NHWC/NCHW), stacked attention
   * - DNNL
     - ``dnnl.*``
     - Matmul, conv2d (x86 CPU). Codegen exists at the C++ level; patterns are
       defined in tests rather than pre-registered.

Source Code Map
---------------

.. list-table::
   :header-rows: 1
   :widths: 50 50

   * - Path
     - Contents
   * - ``python/tvm/relax/backend/pattern_registry.py``
     - Pattern registry API (``register_patterns``, ``get_patterns_with_prefix``)
   * - ``python/tvm/relax/backend/patterns.py``
     - Reusable pattern templates (``make_matmul_pattern``, etc.)
   * - ``python/tvm/relax/backend/cuda/cublas.py``
     - cuBLAS patterns and ``partition_for_cublas``
   * - ``python/tvm/relax/backend/cuda/cutlass.py``
     - CUTLASS patterns and ``partition_for_cutlass``
   * - ``python/tvm/relax/backend/cuda/cudnn.py``
     - cuDNN patterns and ``partition_for_cudnn``
   * - ``src/relax/backend/pattern_registry.cc``
     - Pattern registry C++ implementation
   * - ``src/relax/transform/run_codegen.cc``
     - RunCodegen pass (``CodeGenRunner``)
   * - ``src/relax/transform/merge_composite_functions.cc``
     - MergeCompositeFunctions pass
   * - ``src/relax/backend/contrib/cublas/codegen.cc``
     - cuBLAS codegen (JSONSerializer-based)
   * - ``src/relax/backend/contrib/cutlass/codegen.cc``
     - CUTLASS codegen
   * - ``src/relax/backend/contrib/codegen_json/codegen_json.h``
     - JSONSerializer base class
   * - ``src/relax/backend/contrib/codegen_c/codegen_c.h``
     - CSourceCodegen base class

docs/arch/index.rst

For backends such as x86 and ARM, we use the LLVM IRBuilder to build in-memory LLVM IR.
We can also generate source-level languages such as CUDA C and OpenCL.
Finally, we support direct translations of a Relax function (sub-graph) to specific targets via external code generators.
See :ref:`external-library-dispatch` for the full BYOC (Bring Your Own Codegen) pipeline that
offloads operator subgraphs to vendor libraries like cuBLAS, CUTLASS, and cuDNN.
It is important that the final code generation phase is as lightweight as possible. The vast majority of transformations
and lowering should be performed before the target translation phase.

.. toctree::
   :maxdepth: 1

   external_library_dispatch

We also provide a Target structure to specify the compilation target.
The transformations before the target translation phase can also be affected by the target — for example,
a target's vector length would change the vectorization behavior.
