32 changes: 32 additions & 0 deletions docsrc/tutorials/resource_memory/engine_cache.rst
@@ -129,6 +129,38 @@ The timing cache is always active and persisted at ``timing_cache_path``:
The default path is
``/tmp/torch_tensorrt_engine_cache/timing_cache.bin``.

.. note::

The timing cache is **not used with TensorRT-RTX**, which does not perform
autotuning. For TensorRT-RTX, see the *Runtime Cache* section below.

Runtime Cache (TensorRT-RTX)
-----------------------------

TensorRT-RTX uses JIT compilation at inference time. The **runtime cache** stores
these compilation results so that kernels and execution graphs are not recompiled
on subsequent runs. This is analogous to the timing cache but operates at inference
time rather than build time.

The runtime cache is automatically created when using TensorRT-RTX and can be
persisted to disk via ``runtime_cache_path``:

.. code-block:: python

trt_gm = torch_tensorrt.dynamo.compile(
exported_program,
arg_inputs=inputs,
runtime_cache_path="/data/trt_cache/runtime_cache.bin",
use_python_runtime=True,
)

The default path is
``/tmp/torch_tensorrt_engine_cache/runtime_cache.bin``.

The cache is saved to disk when the module is destroyed (garbage collected) and
loaded on subsequent compilations with the same path. File locking is used to
prevent corruption when multiple processes share the same cache file.
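The locking scheme described above can be sketched with the ``filelock`` package (which this change adds to ``py/requirements.txt``). The helper names and the ``.lock`` sidecar path below are illustrative assumptions for the example, not actual torch_tensorrt APIs:

```python
# Illustrative sketch of lock-guarded cache persistence using `filelock`.
# Function names and the ".lock" sidecar file are assumptions, not
# torch_tensorrt internals.
import os
from filelock import FileLock

def save_runtime_cache(path: str, blob: bytes) -> None:
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with FileLock(path + ".lock"):  # serialize writers across processes
        with open(path, "wb") as f:
            f.write(blob)

def load_runtime_cache(path: str) -> bytes:
    if not os.path.isfile(path):
        return b""  # cold cache: nothing to load yet
    with FileLock(path + ".lock"):  # avoid reading a half-written file
        with open(path, "rb") as f:
            return f.read()
```

Taking the lock around both reads and writes ensures one process never observes a partially written cache file while another is saving.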

----

Custom Cache Backends
10 changes: 8 additions & 2 deletions docsrc/user_guide/compilation/compilation_settings.rst
@@ -372,7 +372,7 @@ Compilation Workflow
- ``False``
- Defer TRT engine deserialization until all engines have been built.
Works around resource constraints and builder overhead but engines
may be less well tuned to their deployment resource availablity
may be less well tuned to their deployment resource availability
* - ``debug``
- ``False``
- Enable verbose TRT builder logs at ``DEBUG`` level.
@@ -402,7 +402,13 @@ Engine Caching
- ``/tmp/torch_tensorrt_engine_cache/timing_cache.bin``
- Path for TRT's timing cache file. The timing cache records kernel timing data
across sessions, speeding up subsequent engine builds for similar subgraphs even
when the engine cache itself is cold.
when the engine cache itself is cold. Not used for TensorRT-RTX (no autotuning).
* - ``runtime_cache_path``
- ``/tmp/torch_tensorrt_engine_cache/runtime_cache.bin``
- Path for the TensorRT-RTX runtime cache file. The runtime cache stores JIT
compilation results at inference time, preventing repeated compilation of
kernels and graphs across sessions. Uses file locking for concurrent access
safety. Only used with TensorRT-RTX; ignored for standard TensorRT.
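The default paths in the table above are composed from the system temp directory; the snippet below mirrors how the defaults are built in ``_defaults.py`` in this change:

```python
# How the default cache paths are composed: both files live under a shared
# torch_tensorrt_engine_cache directory inside the system temp dir.
import os
import tempfile

TIMING_CACHE_PATH = os.path.join(
    tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
)
RUNTIME_CACHE_PATH = os.path.join(
    tempfile.gettempdir(), "torch_tensorrt_engine_cache", "runtime_cache.bin"
)
```

On Linux this yields ``/tmp/torch_tensorrt_engine_cache/...``; on other platforms the prefix follows ``tempfile.gettempdir()``.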

----

1 change: 1 addition & 0 deletions py/requirements.txt
@@ -6,4 +6,5 @@ torch>=2.12.0.dev,<2.13.0
--extra-index-url https://pypi.ngc.nvidia.com
pyyaml
dllist
filelock
setuptools
35 changes: 32 additions & 3 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -92,6 +92,8 @@ def cross_compile_for_windows(
dryrun: bool = _defaults.DRYRUN,
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -170,7 +172,9 @@ def cross_compile_for_windows(
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
@@ -334,6 +338,8 @@ def cross_compile_for_windows(
"dryrun": dryrun,
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
@@ -366,6 +372,12 @@
f"arg: {key} is not supported for cross compilation for windows feature, hence it is disabled."
)

if "runtime_cache_path" in compilation_options:
compilation_options.pop("runtime_cache_path")
logger.warning(
"runtime_cache_path is a JIT-time API and is not applicable to cross compilation for windows. Ignoring."
)

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
exported_program = pre_export_lowering(exported_program, settings)
@@ -438,6 +450,8 @@ def compile(
dryrun: bool = _defaults.DRYRUN,
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -531,7 +545,9 @@ def compile(
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
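As a mental model, the three ``dynamic_shapes_kernel_specialization_strategy`` values documented above can be pictured as a small kernel dispatcher. This is an illustrative sketch only; every name in it is hypothetical, and the real specialization logic lives inside TensorRT-RTX:

```python
# Hypothetical model of the three strategies. `compile_fn` builds a
# shape-specialized kernel; `fallback` is a generic kernel for any shape.
import threading

def select_kernel(shape, cache, compile_fn, fallback, strategy="lazy"):
    if strategy == "none":
        return fallback                    # never specialize
    if shape in cache:
        return cache[shape]                # reuse a finished specialization
    if strategy == "eager":
        cache[shape] = compile_fn(shape)   # block until the kernel is ready
        return cache[shape]
    # "lazy": start compiling in the background, serve the fallback now
    threading.Thread(
        target=lambda: cache.setdefault(shape, compile_fn(shape)),
        daemon=True,
    ).start()
    return fallback
```

Under "lazy", the first call for a new shape returns immediately with the generic kernel and later calls pick up the specialized one once compilation finishes; "eager" trades first-call latency for an immediately specialized kernel.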
@@ -738,6 +754,8 @@ def compile(
"dryrun": dryrun,
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
@@ -1150,6 +1168,8 @@ def convert_exported_program_to_serialized_trt_engine(
dryrun: bool = _defaults.DRYRUN,
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -1224,7 +1244,9 @@ def convert_exported_program_to_serialized_trt_engine(
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
@@ -1397,6 +1419,8 @@ def convert_exported_program_to_serialized_trt_engine(
"dryrun": dryrun,
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
@@ -1413,6 +1437,11 @@
"use_distributed_mode_trace": use_distributed_mode_trace,
"decompose_attention": decompose_attention,
}
if "runtime_cache_path" in compilation_options:
compilation_options.pop("runtime_cache_path")
logger.warning(
"runtime_cache_path is a JIT-time API and is not applicable to serialized engine export. Ignoring."
)

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -40,6 +40,9 @@
TIMING_CACHE_PATH = os.path.join(
tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
)
RUNTIME_CACHE_PATH = os.path.join(
tempfile.gettempdir(), "torch_tensorrt_engine_cache", "runtime_cache.bin"
)
LAZY_ENGINE_INIT = False
CACHE_BUILT_ENGINES = False
REUSE_CACHED_ENGINES = False
@@ -68,6 +71,7 @@
CPU_MEMORY_BUDGET = None
DYNAMICALLY_ALLOCATE_RESOURCES = False
DECOMPOSE_ATTENTION = False
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy"

if platform.system() == "Linux":
import pwd
10 changes: 9 additions & 1 deletion py/torch_tensorrt/dynamo/_settings.py
@@ -22,6 +22,7 @@
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DRYRUN,
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
DYNAMICALLY_ALLOCATE_RESOURCES,
ENABLE_AUTOCAST,
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
@@ -43,6 +44,7 @@
REFIT_IDENTICAL_ENGINE_WEIGHTS,
REQUIRE_FULL_COMPILATION,
REUSE_CACHED_ENGINES,
RUNTIME_CACHE_PATH,
SPARSE_WEIGHTS,
STRIP_ENGINE_WEIGHTS,
TILING_OPTIMIZATION_LEVEL,
@@ -96,7 +98,9 @@ class CompilationSettings:
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the
output to a file if a string path is specified
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning).
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy".
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
@@ -149,6 +153,10 @@ class CompilationSettings:
dryrun: Union[bool, str] = DRYRUN
hardware_compatible: bool = HARDWARE_COMPATIBLE
timing_cache_path: str = TIMING_CACHE_PATH
runtime_cache_path: str = RUNTIME_CACHE_PATH
dynamic_shapes_kernel_specialization_strategy: str = (
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY
)
lazy_engine_init: bool = LAZY_ENGINE_INIT
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
13 changes: 12 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -379,7 +379,14 @@ def _create_timing_cache(
"""
Create a timing cache to enable faster build time for TRT engines.
    By default timing_cache_path="/tmp/torch_tensorrt_engine_cache/timing_cache.bin"
Skipped for TensorRT-RTX since it does not use autotuning.
"""
if ENABLED_FEATURES.tensorrt_rtx:
_LOGGER.info(
"Skipping timing cache creation for TensorRT-RTX (no autotuning)"
)
return

buffer = b""
if os.path.isfile(timing_cache_path):
# Load from existing cache
@@ -394,8 +401,12 @@ def _save_timing_cache(
timing_cache_path: str,
) -> None:
"""
This is called after a TensorRT engine is built. Save the timing cache
This is called after a TensorRT engine is built. Save the timing cache.
Skipped for TensorRT-RTX since it does not use autotuning.
"""
if ENABLED_FEATURES.tensorrt_rtx:
return

timing_cache = builder_config.get_timing_cache()
os.makedirs(os.path.dirname(timing_cache_path), exist_ok=True)
with open(timing_cache_path, "wb") as timing_cache_file:
@@ -110,7 +110,8 @@ def __init__(
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
**kwargs: Any,