Skip to content

Commit 29abd38

Browse files
Bowen Fuclaude
andcommitted
feat(annotation): add TTA annotation layer (export_as, lower_as, custom_plugin)
Adds torch_tensorrt.annotation (aliased as tta) — a zero-overhead annotation layer that lets users tag regions of a PyTorch model for custom TensorRT lowering, without modifying core torch_tensorrt internals beyond a small set of generic extension hooks. Core hooks added to torch_tensorrt (outside annotation/): - _compiler.py: generic extension hook registries (register_compile_pass, register_preserved_ep_attr, register_export_context, register_post_trace_hook) + EP attribute preservation across run_decompositions() + compile pass loop - _compile.py: post-trace hook loop - _settings.py: profiling_verbosity field in CompilationSettings - _tracer.py: ExitStack wrapping for registered export context factories - _ConversionContext.py: current_node field for converter access - _TRTInterpreter.py: PREFER_AOT_PYTHON_PLUGINS flag, profiling_verbosity routing, current_node tracking, layer metadata stamping Annotation module (py/torch_tensorrt/annotation/): - export_as: context manager to tag regions during tracing - lower_as: context manager to specify custom TRT lowering (builtin/plugin/kernel) - custom_plugin: Triton/CuTile/CuTeDSL AOT QDP plugin descriptors - IR layer: region discovery, boundary validation, region views - Full test suite (599 unit/integration tests + 25 pre-Blackwell tests) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 2e26bfa commit 29abd38

79 files changed

Lines changed: 28236 additions & 11 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

py/torch_tensorrt/_compile.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,12 @@ def _fx_input_interface(
320320
kwarg_inputs=torchtrt_kwarg_inputs,
321321
**kwargs,
322322
)
323+
# Run post-trace hooks.
324+
from torch_tensorrt.dynamo._compiler import _post_trace_hooks
325+
for _hook in _post_trace_hooks:
326+
_result = _hook(exp_program, torchtrt_arg_inputs)
327+
if _result is not None:
328+
exp_program = _result
323329
trt_graph_module = dynamo_compile(
324330
exp_program,
325331
arg_inputs=torchtrt_arg_inputs,

py/torch_tensorrt/annotation/README.md

Lines changed: 641 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)