.. Licensed to the Apache Software Foundation (ASF) under one
   or more contributor license agreements. See the NOTICE file
   distributed with this work for additional information
   regarding copyright ownership. The ASF licenses this file
   to you under the Apache License, Version 2.0 (the
   "License"); you may not use this file except in compliance
   with the License. You may obtain a copy of the License at

..   http://www.apache.org/licenses/LICENSE-2.0

.. Unless required by applicable law or agreed to in writing,
   software distributed under the License is distributed on an
   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
   KIND, either express or implied. See the License for the
   specific language governing permissions and limitations
   under the License.

.. _external-library-dispatch:

External Library Dispatch (BYOC)
================================

When deploying models, certain operator patterns (e.g., matmul + bias + relu) can be executed
more efficiently by vendor-optimized libraries such as cuBLAS, CUTLASS, cuDNN, or DNNL. TVM's
**BYOC (Bring Your Own Codegen)** mechanism identifies these patterns in a Relax module and
offloads them to external backends, while keeping the rest of the computation on TVM's own
generated kernels.

This document explains the BYOC pipeline: how patterns are registered, how subgraphs are
matched and extracted, how backend code generators are invoked, and how the externally compiled
code is executed at runtime.


Overview
--------

The BYOC pipeline consists of four stages:

.. code-block:: text

    IRModule (high-level Relax IR)
        │
        ▼  FuseOpsByPattern             ← match high-level ops, create composite functions
    IRModule (with Composite + Codegen attributes)
        │
        ▼  RunCodegen                   ← invoke backend codegen via FFI
    IRModule (with call_dps_packed to ExternFunc)
      + external runtime Modules
        │
        ▼  LegalizeOps + FuseOps + ...  ← compile remaining ops normally
        │
        ▼  VM compilation               ← link external modules into executable
    Deployable artifact

Each stage is a Relax transformation pass that operates on the ``IRModule``:

1. **FuseOpsByPattern** — matches operator subgraphs against registered patterns and groups them
   into composite functions annotated with ``Composite`` and ``Codegen`` attributes.
2. **MergeCompositeFunctions** (optional) — merges multiple composite functions targeting the same
   backend when inter-operator dependencies allow.
3. **RunCodegen** — finds all functions with a ``Codegen`` attribute, invokes the corresponding
   backend code generator via FFI, and replaces the original calls with ``call_dps_packed``
   to externally compiled functions.
4. **Linking** — the resulting external ``runtime.Module``\ s are attached to the ``IRModule``
   as the ``external_mods`` attribute and bundled into the final executable during
   ``relax.build()`` (see the sketch below).
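
A minimal sketch of chaining these stages from Python, assuming a TVM build with cuBLAS
support enabled (``mod`` is any Relax ``IRModule``; the exact pass list after ``RunCodegen``
varies by workload):

.. code-block:: python

    import tvm
    from tvm import relax
    from tvm.relax.backend.cuda.cublas import partition_for_cublas

    mod = partition_for_cublas(mod)           # FuseOpsByPattern with cublas patterns
    mod = relax.transform.RunCodegen()(mod)   # invoke the cuBLAS codegen via FFI
    mod = relax.transform.LegalizeOps()(mod)  # lower the ops that were not offloaded
    ex = relax.build(mod, target="cuda")      # link external modules into the executable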


Pattern Registration
--------------------

Each backend registers the operator patterns it supports in a **global pattern registry**
(``python/tvm/relax/backend/pattern_registry.py``). The registry is a static table that maps
pattern names to ``FusionPattern`` objects.

Registering patterns
~~~~~~~~~~~~~~~~~~~~

.. code-block:: python

    from tvm.relax.backend.pattern_registry import register_patterns
    from tvm.relax.backend.patterns import make_matmul_pattern

    register_patterns([
        (
            "cublas.matmul",          # pattern name (prefix = backend)
            *make_matmul_pattern(     # returns (DFPattern, annotation_patterns)
                with_bias=False,
            ),
            _check_matmul,            # check function
        ),
        (
            "cublas.matmul_bias_relu",
            *make_matmul_pattern(
                with_bias=True,
                activation="relax.nn.relu",
            ),
            _check_matmul,
        ),
        # ... more patterns
    ])

Each entry is a tuple of ``(name, pattern, annotation_patterns, check_func)`` that gets
converted to a ``FusionPattern`` object. The name prefix (e.g., ``"cublas"``) identifies the
backend; ``get_patterns_with_prefix("cublas")`` retrieves all patterns for that backend.

Patterns registered later have **higher priority** — when a subgraph matches multiple patterns,
the highest-priority match wins.

Pattern templates
~~~~~~~~~~~~~~~~~

``python/tvm/relax/backend/patterns.py`` provides reusable templates for common patterns:

- ``make_matmul_pattern(with_bias, activation, transposed_rhs)`` — matmul with optional bias
  and activation fusion
- ``make_conv2d_pattern(with_bias, activation)`` — 2D convolution
- ``make_attention_pattern()`` — multi-head attention
- ``make_residual_block_pattern()`` — residual connections
- ``make_layer_norm_pattern()`` / ``make_rms_norm_pattern()`` — normalization layers

Each template returns ``(DFPattern, Mapping[str, DFPattern])`` — the main pattern and its
annotation sub-patterns.
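
For patterns the templates do not cover, a ``DFPattern`` can be assembled directly. The
following hedged sketch uses the ``tvm.relax.dpl`` pattern API to approximate what
``make_matmul_pattern(with_bias=False)`` produces; the annotation names are illustrative:

.. code-block:: python

    from tvm.relax.dpl.pattern import is_op, wildcard

    lhs = wildcard()
    rhs = wildcard()
    out = is_op("relax.matmul")(lhs, rhs)

    # The annotation mapping lets the check function refer to matched
    # sub-expressions by name.
    annotations = {"lhs": lhs, "rhs": rhs, "root": out}
    # register_patterns([("mybackend.matmul", out, annotations, _check_matmul)])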

Check functions
~~~~~~~~~~~~~~~

The check function validates whether a matched subgraph can actually be handled by the backend.
It receives a ``PatternCheckContext`` and returns ``True`` to accept or ``False`` to reject.

Typical checks include:

- **Data type support**: verify that the operand dtypes are supported (e.g., cuBLAS supports
  float16, float32, int8, bfloat16, and float8 for matmul).
- **Shape constraints**: verify that reduction axes are constant and batch dimensions are
  compatible.
- **Leaking intermediates**: reject the match if an intermediate result is used outside the
  fused group (via ``has_leaking_intermediate_variables()``).
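
A sketch of a dtype check in this style, assuming the leaking-intermediates helper lives in
``tvm.relax.backend.utils`` and that the pattern annotated its operands as ``lhs`` and
``rhs`` (annotation names vary per pattern):

.. code-block:: python

    from tvm.relax.backend.utils import has_leaking_intermediate_variables
    from tvm.relax.transform import PatternCheckContext

    def _check_matmul(context: PatternCheckContext) -> bool:
        # Reject matches whose intermediate results escape the fused group.
        if has_leaking_intermediate_variables(context):
            return False
        lhs = context.annotated_expr["lhs"]
        rhs = context.annotated_expr["rhs"]
        # Accept only dtypes the library kernels handle.
        lhs_dtype = lhs.struct_info.dtype
        rhs_dtype = rhs.struct_info.dtype
        return lhs_dtype == rhs_dtype and lhs_dtype in ("float16", "float32")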


Partitioning
------------

After patterns are registered, a backend provides a **partition function** that applies
``FuseOpsByPattern`` to an ``IRModule``:

.. code-block:: python

    # python/tvm/relax/backend/cuda/cublas.py
    def partition_for_cublas(mod, bind_constants=False):
        patterns = get_patterns_with_prefix("cublas")
        return transform.FuseOpsByPattern(
            patterns, bind_constants=bind_constants, annotate_codegen=True
        )(mod)

With ``annotate_codegen=True``, each matched subgraph is wrapped in a two-level function
structure:

.. code-block:: text

    # Outer function — tagged for the codegen backend
    @R.function
    def fused_relax_matmul_cublas0(args...):
        R.func_attr({"Codegen": "cublas", "global_symbol": "fused_relax_matmul_cublas0"})
        ...
        # Inner function — identifies the specific pattern
        @R.function(private=True)
        def composite(args...):
            R.func_attr({"Composite": "cublas.matmul_bias_relu"})
            lv0 = R.matmul(x, w)
            lv1 = R.add(lv0, bias)
            lv2 = R.nn.relu(lv1)
            return lv2
        ...

The outer function carries the ``Codegen`` attribute that ``RunCodegen`` uses to dispatch to the
right backend. The inner function carries the ``Composite`` attribute that the backend codegen
uses to identify which operation to emit.

MergeCompositeFunctions
~~~~~~~~~~~~~~~~~~~~~~~

When ``annotate_codegen=False``, ``FuseOpsByPattern`` only creates inner functions with
``Composite`` attributes. A separate ``MergeCompositeFunctions`` pass then groups multiple
composite functions targeting the same backend into a single outer function with ``Codegen``
and ``global_symbol`` attributes.

This is useful when multiple sequential operations should be sent to the same backend as a
single unit (e.g., a sequence of cuBLAS matmuls that share intermediate results). The pass
checks that merging does not create cyclic dependencies between groups.
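
A sketch of this alternative flow, assuming ``patterns`` was obtained from
``get_patterns_with_prefix``:

.. code-block:: python

    from tvm import relax

    # Create Composite-only inner functions, then merge them per backend.
    mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=False)(mod)
    mod = relax.transform.MergeCompositeFunctions()(mod)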


Code Generation
---------------

``RunCodegen`` (``src/relax/transform/run_codegen.cc``) is the pass that triggers backend
code generation:

1. Scan the module for all functions with a ``Codegen`` attribute.
2. Group them by backend target name.
3. For each backend, look up the registered codegen function via the FFI key
   ``"relax.ext.<backend>"`` (e.g., ``"relax.ext.cublas"``).
4. Call the codegen function, which returns an array of compiled ``runtime.Module``\ s.
5. Replace the original function calls with ``call_dps_packed(ExternFunc(...), args)``.
6. Attach the compiled modules to the ``IRModule`` as the ``external_mods`` attribute.
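
``RunCodegen`` can also forward backend-specific options, keyed by backend name, which arrive
as the ``options`` argument of the codegen function. A hedged sketch (the CUTLASS option name
shown is illustrative):

.. code-block:: python

    from tvm import relax

    # The nested dict is delivered to the matching "relax.ext.<backend>" function.
    mod = relax.transform.RunCodegen({"cutlass": {"sm": 80}})(mod)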

Codegen registration
~~~~~~~~~~~~~~~~~~~~

Each backend registers a codegen function via TVM's FFI mechanism:

.. code-block:: cpp

    // src/relax/backend/contrib/cublas/codegen.cc
    ffi::Array<ffi::Module> CublasCompiler(
        ffi::Array<Function> functions,
        ffi::Map<ffi::String, ffi::Any> options,
        ffi::Map<Constant, ffi::String> constant_names) {
      ffi::Array<ffi::Module> compiled_functions;
      for (const auto& func : functions) {
        CublasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
        serializer.serialize(func);
        auto graph_json = serializer.GetJSON();
        auto names = serializer.GetConstantNames();
        const auto pf = ffi::Function::GetGlobalRequired("runtime.CublasJSONRuntimeCreate");
        compiled_functions.push_back(
            pf(GetExtSymbol(func), graph_json, names).cast<ffi::Module>());
      }
      return compiled_functions;
    }

    TVM_FFI_STATIC_INIT_BLOCK() {
      namespace refl = tvm::ffi::reflection;
      refl::GlobalDef().def("relax.ext.cublas", CublasCompiler);
    }

The codegen function receives:

- ``functions``: the Relax functions with a ``Codegen`` attribute to compile.
- ``options``: backend-specific compilation options.
- ``constant_names``: mapping from constant values to their names (for weight handling).

It returns an array of ``runtime.Module`` objects — one per function — that contain the
externally compiled code.
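
Because the lookup goes through the global function registry, a codegen hook can also be
registered from Python. This is a hedged sketch for a hypothetical ``mybackend``;
``build_mybackend_module`` stands in for whatever serialization and module construction the
backend needs:

.. code-block:: python

    import tvm

    @tvm.register_func("relax.ext.mybackend")
    def _mybackend_compiler(functions, options, constant_names):
        # Return one runtime.Module per offloaded function.
        return [build_mybackend_module(func) for func in functions]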

Codegen strategies
~~~~~~~~~~~~~~~~~~

TVM provides two base classes for implementing backend codegens:

- **JSONSerializer** (``src/relax/backend/contrib/codegen_json/codegen_json.h``): serializes the
  composite function into a JSON graph representation. At runtime, a backend-specific JSON
  runtime module interprets the graph and dispatches to library calls. Used by cuBLAS, cuDNN,
  and most other backends.

- **CSourceCodegen** (``src/relax/backend/contrib/codegen_c/codegen_c.h``): generates C/CUDA
  source code that is compiled and linked. Used when the backend requires ahead-of-time
  compilation.


Runtime Execution
-----------------

After ``RunCodegen``, the original high-level function calls are replaced with:

.. code-block:: python

    R.call_dps_packed(ExternFunc("fused_relax_matmul_cublas0"), (x, w, bias), ...)

At runtime, ``call_dps_packed`` invokes the externally compiled function through the
``PackedFunc`` interface. The external ``runtime.Module``\ s (produced by the codegen) are
imported into the final executable during ``relax.build()`` and are available via the module's
function lookup mechanism.

For JSON-based backends (cuBLAS, cuDNN), the runtime module deserializes the JSON graph and
dispatches each node to the corresponding library API call. For source-based backends, the
compiled native code is called directly.
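
A minimal sketch of running the built artifact on the Relax VM, assuming a CUDA target and a
``main`` function taking two float16 tensors (the shapes are illustrative):

.. code-block:: python

    import numpy as np
    import tvm
    from tvm import relax

    ex = relax.build(mod, target="cuda")   # external modules are bundled here
    dev = tvm.cuda(0)
    vm = relax.VirtualMachine(ex, dev)

    x = tvm.nd.array(np.random.rand(32, 64).astype("float16"), dev)
    w = tvm.nd.array(np.random.rand(64, 128).astype("float16"), dev)
    out = vm["main"](x, w)                 # offloaded calls dispatch to cuBLAS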


Adding a New Backend
--------------------

To add support for a new external library:

1. **Define patterns** in ``python/tvm/relax/backend/<target>/``:

   - Create DFPatterns using templates from ``patterns.py`` or custom patterns.
   - Write check functions to validate dtypes, shapes, and other constraints.
   - Register patterns with ``register_patterns()``.
   - Provide a ``partition_for_<backend>(mod)`` convenience function (see the sketch below).

2. **Implement codegen** in ``src/relax/backend/contrib/<target>/``:

   - Subclass ``JSONSerializer`` or ``CSourceCodegen``.
   - Implement the visitor that converts composite functions to the target format.
   - Register the codegen function as ``"relax.ext.<target>"``.

3. **Implement runtime** (for JSON-based backends):

   - Create a JSON runtime module that interprets the serialized graph and dispatches
     to the library's API calls.
   - Register the runtime constructor as ``"runtime.<Target>JSONRuntimeCreate"``.
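
Putting step 1 together for a hypothetical ``mylib`` backend (the check function
``_check_mylib`` and the C++ codegen registered as ``"relax.ext.mylib"`` are assumed to
exist elsewhere):

.. code-block:: python

    from tvm.relax import transform
    from tvm.relax.backend.pattern_registry import (
        get_patterns_with_prefix,
        register_patterns,
    )
    from tvm.relax.backend.patterns import make_matmul_pattern

    register_patterns([
        ("mylib.matmul", *make_matmul_pattern(with_bias=False), _check_mylib),
    ])

    def partition_for_mylib(mod):
        patterns = get_patterns_with_prefix("mylib")
        return transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)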


Supported Backends
------------------

.. list-table::
   :header-rows: 1
   :widths: 15 25 60

   * - Backend
     - Patterns
     - Operations
   * - cuBLAS
     - ``cublas.*``
     - Matmul (with bias, activation, transpose, dequantize variants)
   * - CUTLASS
     - ``cutlass.*``
     - Matmul, conv2d, attention, residual blocks, decode matmul
   * - cuDNN
     - ``cudnn.*``
     - Conv2d (NHWC/NCHW), stacked attention
   * - DNNL
     - ``dnnl.*``
     - Matmul, conv2d (x86 CPU). Codegen exists at the C++ level; patterns are
       defined in tests rather than pre-registered.


Source Code Map
---------------

.. list-table::
   :header-rows: 1
   :widths: 50 50

   * - Path
     - Contents
   * - ``python/tvm/relax/backend/pattern_registry.py``
     - Pattern registry API (``register_patterns``, ``get_patterns_with_prefix``)
   * - ``python/tvm/relax/backend/patterns.py``
     - Reusable pattern templates (``make_matmul_pattern``, etc.)
   * - ``python/tvm/relax/backend/cuda/cublas.py``
     - cuBLAS patterns and ``partition_for_cublas``
   * - ``python/tvm/relax/backend/cuda/cutlass.py``
     - CUTLASS patterns and ``partition_for_cutlass``
   * - ``python/tvm/relax/backend/cuda/cudnn.py``
     - cuDNN patterns and ``partition_for_cudnn``
   * - ``src/relax/backend/pattern_registry.cc``
     - Pattern registry C++ implementation
   * - ``src/relax/transform/run_codegen.cc``
     - ``RunCodegen`` pass (``CodeGenRunner``)
   * - ``src/relax/transform/merge_composite_functions.cc``
     - ``MergeCompositeFunctions`` pass
   * - ``src/relax/backend/contrib/cublas/codegen.cc``
     - cuBLAS codegen (JSONSerializer-based)
   * - ``src/relax/backend/contrib/cutlass/codegen.cc``
     - CUTLASS codegen
   * - ``src/relax/backend/contrib/codegen_json/codegen_json.h``
     - ``JSONSerializer`` base class
   * - ``src/relax/backend/contrib/codegen_c/codegen_c.h``
     - ``CSourceCodegen`` base class