From 5ced04fdcb3a049bb5b87b65c738ba3d18681fcd Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Tue, 5 May 2026 23:57:59 -0700 Subject: [PATCH 1/9] feat(runtime): add TensorRT-RTX runtime cache, dynamic shapes strategy, and native CUDA graph support to C++ runtime - Introduce IRuntimeConfig scaffolding and bump ABI to v9 - Add runtime cache to C++ runtime for TensorRT-RTX - Add dynamic shapes kernel specialization strategy to C++ runtime - Add TensorRT-RTX native CUDA graph strategy to C++ runtime - Extract TRTRuntimeConfig - Consolidate C++ runtime tests and add model-level coverage --- core/runtime/BUILD | 16 +- core/runtime/TRTEngine.cpp | 111 +++++--- core/runtime/TRTEngine.h | 44 +++- core/runtime/TRTRuntimeConfig.cpp | 245 ++++++++++++++++++ core/runtime/TRTRuntimeConfig.h | 95 +++++++ core/runtime/execute_engine.cpp | 31 ++- core/runtime/register_jit_hooks.cpp | 5 + core/runtime/runtime.h | 5 + .../dynamo/runtime/_TorchTensorRTModule.py | 57 +++- .../runtime/_serialized_engine_layout.py | 23 +- 10 files changed, 575 insertions(+), 57 deletions(-) create mode 100644 core/runtime/TRTRuntimeConfig.cpp create mode 100644 core/runtime/TRTRuntimeConfig.h diff --git a/core/runtime/BUILD b/core/runtime/BUILD index 7f594ecea7..48d6441352 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -88,6 +88,7 @@ cc_library( "RTDevice.cpp", "TRTEngine.cpp", "TRTEngineProfiler.cpp", + "TRTRuntimeConfig.cpp", "execute_engine.cpp", "runtime.cpp", "runtime_utils.cpp", @@ -95,15 +96,23 @@ cc_library( hdrs = [ "Platform.h", "RTDevice.h", - "TensorRTBindingNames.h", "TRTEngine.h", "TRTEngineProfiler.h", + "TRTRuntimeConfig.h", + "TensorRTBindingNames.h", "runtime.h", ], copts = if_torch_nccl(["-DUSE_C10D_NCCL"]), linkopts = [ "-lstdc++fs", ], + local_defines = select({ + # TensorRT-RTX builds: opt into feature-gated APIs that the runtime layer + # depends on (e.g. IExecutionContext::isStreamCapturable). 
+ ":rtx_win": ["ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION"], + ":rtx_x86_64": ["ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION"], + "//conditions:default": [], + }), deps = [ ":tensorrt_binding_names", "//core/plugins:torch_tensorrt_plugins", @@ -135,9 +144,9 @@ cc_library( hdrs = [ "Platform.h", "RTDevice.h", - "TensorRTBindingNames.h", "TRTEngine.h", "TRTEngineProfiler.h", + "TensorRTBindingNames.h", "runtime.h", ], deps = [ @@ -151,9 +160,10 @@ filegroup( srcs = [ "Platform.h", "RTDevice.h", - "TensorRTBindingNames.h", "TRTEngine.h", "TRTEngineProfiler.h", + "TRTRuntimeConfig.h", + "TensorRTBindingNames.h", "runtime.h", ], visibility = ["//visibility:public"], diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 2b97af750c..fe928f5cc4 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -70,26 +71,28 @@ void TRTEngine::record_active_input_tensor_stream_usage(const c10::cuda::CUDAStr } TRTEngine::TRTEngine( - const std::string& serialized_engine, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - const std::string& serialized_metadata, - const ResourceAllocationStrategy resource_allocation_strategy) + std::string serialized_metadata, + const ResourceAllocationStrategy resource_allocation_strategy, + TRTRuntimeConfig runtime_cfg) : TRTEngine( "deserialized_trt", - serialized_engine, + std::move(serialized_engine), cuda_device, _in_binding_names, _out_binding_names, target_platform, hardware_compatible, requires_output_allocator, - serialized_metadata, - resource_allocation_strategy) {} + std::move(serialized_metadata), + resource_allocation_strategy, + std::move(runtime_cfg)) {} TRTEngine::TRTEngine(std::vector serialized_info) : TRTEngine( @@ -104,7 +107,8 @@ TRTEngine::TRTEngine(std::vector 
serialized_info) serialized_info[SERIALIZED_METADATA_IDX], (static_cast(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic - : ResourceAllocationStrategy::kStatic)) { + : ResourceAllocationStrategy::kStatic), + make_runtime_config_from_serialized(serialized_info)) { this->requires_native_multidevice = std::stoi(serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]); if (this->requires_native_multidevice) { LOG_INFO("Loaded distributed TRT engine (contains NCCL collectives); NCCL comm will be bound on first execution"); @@ -112,16 +116,18 @@ TRTEngine::TRTEngine(std::vector serialized_info) } TRTEngine::TRTEngine( - const std::string& mod_name, - const std::string& serialized_engine, + std::string mod_name, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - const std::string& serialized_metadata, - const ResourceAllocationStrategy resource_allocation_strategy) { + std::string serialized_metadata, + const ResourceAllocationStrategy resource_allocation_strategy, + TRTRuntimeConfig runtime_cfg) { + this->runtime_cfg = std::move(runtime_cfg); TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " @@ -132,7 +138,7 @@ TRTEngine::TRTEngine( auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); - this->serialized_metadata = serialized_metadata; + this->serialized_metadata = std::move(serialized_metadata); this->requires_output_allocator = requires_output_allocator; device_info = most_compatible_device.value(); multi_gpu_device_check(); @@ -142,7 +148,7 @@ TRTEngine::TRTEngine( 
rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger())); - name = slugify(mod_name); + name = slugify(std::move(mod_name)); cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size())); TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine"); @@ -157,13 +163,7 @@ TRTEngine::TRTEngine( LOG_DEBUG( "Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static")); - if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { - this->exec_ctx = - make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); - } - TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context"); + recreate_execution_context(); // Pre-allocate placeholder for empty tensors (TensorRT requires non-null addresses) cudaMalloc(&empty_tensor_placeholder, 1); @@ -270,6 +270,9 @@ TRTEngine::TRTEngine( } TRTEngine::~TRTEngine() { + // Marked noexcept so safe to invoke from a destructor without + // explicit try/catch; any I/O error is logged internally. 
+ runtime_cfg.save_runtime_cache(); trt_engine_profiler.reset(); exec_ctx.reset(); cuda_engine.reset(); @@ -283,8 +286,7 @@ void TRTEngine::disable_profiling() { torch::cuda::synchronize(device_info.id); profile_execution = false; trt_engine_profiler.reset(); - exec_ctx = make_trt(cuda_engine->createExecutionContext()); - TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context"); + recreate_execution_context(); } void TRTEngine::dump_engine_layer_info_to_file(const std::string& path) { @@ -381,10 +383,7 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) { trt_engine_profiler.reset(); } bool result = cuda_engine->setWeightStreamingBudgetV2(budget); - exec_ctx = make_trt(cuda_engine->createExecutionContext()); - TORCHTRT_CHECK( - (exec_ctx.get() != nullptr), - "Unable to recreate TensorRT execution context after setting new device memory budget"); + recreate_execution_context(); if (profile_execution) { enable_profiling(); } @@ -441,6 +440,7 @@ std::string TRTEngine::to_str() const { ss << " Target Platform: " << target_platform << std::endl; ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? 
"Dynamic" : "Static") << std::endl; ss << " Multi-Device Engine: " << (requires_native_multidevice) << std::endl; + ss << runtime_cfg.to_str(); // clang-format on return ss.str(); } @@ -487,7 +487,14 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]), - std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX])); + std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]) +#ifdef TRT_MAJOR_RTX + , + std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]), + std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), + std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX]) +#endif + ); } std::vector TRTEngine::serialize() { @@ -514,6 +521,13 @@ std::vector TRTEngine::serialize() { this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = this->requires_native_multidevice ? "1" : "0"; // rank/world_size are runtime facts (may differ at load time); not serialized. 
+#ifdef TRT_MAJOR_RTX + serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path; + serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string( + static_cast>(runtime_cfg.dynamic_shapes_kernel_strategy)); + serialized_info[CUDA_GRAPH_STRATEGY_IDX] = + std::to_string(static_cast>(runtime_cfg.cuda_graph_strategy)); +#endif return serialized_info; } @@ -525,14 +539,11 @@ void TRTEngine::reset_captured_graph() { void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) { if (new_strategy != this->resource_allocation_strategy) { this->resource_allocation_strategy = new_strategy; - if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) { - LOG_DEBUG("Setting resource allocation strategy to dynamic"); - this->exec_ctx = - make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - LOG_DEBUG("Setting resource allocation strategy to static"); - this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); - } + LOG_DEBUG( + "Setting resource allocation strategy to " + << (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic ? "dynamic" + : "static")); + recreate_execution_context(); } } @@ -642,6 +653,36 @@ void TRTEngine::release_nccl_comm() { } #endif // ENABLE_TRT_NCCL_COLLECTIVES +bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const { + return runtime_cfg.is_monolithic_capturable(exec_ctx.get(), stream); +} + +void TRTEngine::disable_rtx_native_cudagraphs() { + bool was_disabled = runtime_cfg.rtx_native_cudagraphs_disabled; + runtime_cfg.disable_rtx_native_cudagraphs(name); + if (!was_disabled && runtime_cfg.rtx_native_cudagraphs_disabled) { + // The CUDA graph strategy on the IRuntimeConfig has been flipped; rebuild exec_ctx + // so the new strategy takes effect for subsequent enqueueV3 calls. 
+ recreate_execution_context(); + } +} + +void TRTEngine::recreate_execution_context() { + // Flush any kernels the previous execution context may have compiled into the + // runtime cache before creating the replacement. The destructor also saves, but + // doing it here guards against losing compiled kernels across profiling toggles, + // allocator changes, or process kills that happen between allocator changes and + // teardown. No-op on standard TensorRT or when no cache path is configured. + runtime_cfg.save_runtime_cache(); + runtime_cfg.ensure_initialized(cuda_engine.get()); + runtime_cfg.set_execution_context_allocation_strategy( + resource_allocation_strategy == ResourceAllocationStrategy::kDynamic + ? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED + : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC); + exec_ctx = make_trt(cuda_engine->createExecutionContext(runtime_cfg.config.get())); + TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); +} + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index c6d06dfb40..467d917baf 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -15,6 +15,7 @@ #include "torch/custom_class.h" #include "core/runtime/TRTEngineProfiler.h" +#include "core/runtime/TRTRuntimeConfig.h" #include "core/runtime/TensorRTBindingNames.h" #include "core/util/prelude.h" @@ -47,7 +48,14 @@ using FlattenedState = std::tuple< std::tuple, // serialized metadata std::tuple, // Platform std::tuple, // Resource Allocation Strategy - std::tuple>; // requires_native_multidevice + std::tuple // requires_native_multidevice +#ifdef TRT_MAJOR_RTX + , + std::tuple, // Runtime Cache Path (TRT-RTX) + std::tuple, // Dynamic Shapes Kernel Strategy (TRT-RTX) + std::tuple // CUDA Graph Strategy (TRT-RTX) +#endif + >; struct TorchTRTRuntimeStates { // Indicates whether CUDAGraphs were enabled in the 
previous execute_engine @@ -142,31 +150,33 @@ struct TRTEngine : torch::CustomClassHolder { ~TRTEngine(); TRTEngine( - const std::string& serialized_engine, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - const std::string& serialized_metadata = "", + std::string serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = - TRTEngine::ResourceAllocationStrategy::kStatic); + TRTEngine::ResourceAllocationStrategy::kStatic, + TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); TRTEngine(std::vector serialized_info); TRTEngine( - const std::string& mod_name, - const std::string& serialized_engine, + std::string mod_name, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - const std::string& serialized_metadata = "", + std::string serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = - TRTEngine::ResourceAllocationStrategy::kStatic); + TRTEngine::ResourceAllocationStrategy::kStatic, + TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); std::string to_str() const; static void verify_serialization_fmt(const std::vector& serialized_info); @@ -273,6 +283,24 @@ struct TRTEngine : torch::CustomClassHolder { ResourceAllocationStrategy resource_allocation_strategy = kStatic; void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy); ResourceAllocationStrategy get_resource_allocation_strategy(); + + // All TensorRT-RTX-specific IRuntimeConfig state lives here. 
On non-RTX builds this + // still owns a shared IRuntimeConfig (so the execution-context allocation strategy is + // applied via the uniform code path) but the RTX-only setters become no-ops. + TRTRuntimeConfig runtime_cfg; + + // Monolithic-capturability check used when this engine is wrapped by an outer whole-graph + // capture (e.g. CudaGraphsTorchTensorRTModule). Non-RTX builds always return true. + bool is_monolithic_capturable(cudaStream_t stream) const; + + // Disable TensorRT-RTX native CUDA graph capture on this engine (one-shot, invoked when + // an outer stream capture is detected around execute_engine). No-op on non-RTX. + void disable_rtx_native_cudagraphs(); + + private: + // Single entry point that (re)creates exec_ctx. Also creates (once) the IRuntimeConfig + // owned by runtime_cfg and applies all runtime config settings. + void recreate_execution_context(); }; } // namespace runtime diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp new file mode 100644 index 0000000000..0804a0a7fa --- /dev/null +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -0,0 +1,245 @@ +#include "core/runtime/TRTRuntimeConfig.h" + +#include +#include +#include +#include +#include + +#include "core/runtime/runtime.h" +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +// File-local helpers. Kept out of the header because they are only used by this +// translation unit -- TRTEngine now consumes a TRTRuntimeConfig directly and does not +// need the enum conversion helpers. 
+namespace { + +[[nodiscard]] std::string to_string(DynamicShapesKernelStrategy s) { + switch (s) { + case DynamicShapesKernelStrategy::kLazy: + return "lazy"; + case DynamicShapesKernelStrategy::kEager: + return "eager"; + case DynamicShapesKernelStrategy::kNone: + return "none"; + } + TORCHTRT_CHECK( + false, + "Unexpected DynamicShapesKernelStrategy value: " + << static_cast>(s)); +} + +[[nodiscard]] std::string to_string(CudaGraphStrategyOption s) { + switch (s) { + case CudaGraphStrategyOption::kDisabled: + return "disabled"; + case CudaGraphStrategyOption::kWholeGraphCapture: + return "whole_graph_capture"; + } + TORCHTRT_CHECK( + false, + "Unexpected CudaGraphStrategyOption value: " << static_cast>(s)); +} + +[[nodiscard]] DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy( + std::underlying_type_t v) { + TORCHTRT_CHECK( + v >= 0 && v <= 2, + "Invalid dynamic shapes kernel strategy value: " << v << ". Expected 0 (lazy), 1 (eager), or 2 (none)."); + return static_cast(v); +} + +[[nodiscard]] CudaGraphStrategyOption to_cuda_graph_strategy_option(std::underlying_type_t v) { + TORCHTRT_CHECK( + v >= 0 && v <= 1, + "Invalid CUDA graph strategy value: " << v << ". Expected 0 (disabled) or 1 (whole_graph_capture)."); + return static_cast(v); +} + +#ifdef TRT_MAJOR_RTX +// Raw cache I/O helpers. Exception-propagating; the caller wraps in try/catch at the +// TRTRuntimeConfig member level. Kept file-local because the IRuntimeCache type is +// itself TensorRT-RTX-only and tests reach this path through the member wrappers. 
+void load_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache) { + TORCHTRT_CHECK(cache != nullptr, "load_runtime_cache requires a non-null IRuntimeCache"); + if (!std::filesystem::exists(path)) { + LOG_DEBUG("No existing runtime cache at " << path); + return; + } + std::ifstream f(path, std::ios::binary); + std::vector buf((std::istreambuf_iterator(f)), std::istreambuf_iterator()); + if (buf.empty()) { + return; + } + TORCHTRT_CHECK(cache->deserialize(buf.data(), buf.size()), "IRuntimeCache::deserialize returned false for " << path); + LOG_INFO("Loaded runtime cache from " << path << " (" << buf.size() << " bytes)"); +} + +void save_runtime_cache_impl(const std::string& path, nvinfer1::IRuntimeCache* cache) { + TORCHTRT_CHECK(cache != nullptr, "save_runtime_cache requires a non-null IRuntimeCache"); + auto host_mem = make_trt(cache->serialize()); + if (!host_mem || host_mem->size() == 0) { + return; + } + std::filesystem::path fs_path(path); + if (fs_path.has_parent_path()) { + std::filesystem::create_directories(fs_path.parent_path()); + } + std::filesystem::path tmp_path = fs_path; + tmp_path += ".tmp"; + { + std::ofstream out(tmp_path, std::ios::binary); + out.write(reinterpret_cast(host_mem->data()), host_mem->size()); + } + std::filesystem::rename(tmp_path, fs_path); + LOG_INFO("Saved runtime cache to " << path << " (" << host_mem->size() << " bytes)"); +} +#endif // TRT_MAJOR_RTX + +} // namespace + +void TRTRuntimeConfig::ensure_initialized(nvinfer1::ICudaEngine* cuda_engine) { + if (config) { + return; + } + TORCHTRT_CHECK(cuda_engine != nullptr, "Cannot initialize TRTRuntimeConfig without a live ICudaEngine"); + config = make_trt(cuda_engine->createRuntimeConfig()); + TORCHTRT_CHECK(config.get() != nullptr, "Unable to create TensorRT IRuntimeConfig"); + +#ifdef TRT_MAJOR_RTX + // Runtime cache -- TRT-RTX only. 
+ if (!runtime_cache_path.empty()) { + runtime_cache = make_trt(config->createRuntimeCache()); + if (runtime_cache.get() == nullptr) { + LOG_WARNING("Failed to create TensorRT IRuntimeCache; runtime cache will be skipped."); + } else { + try { + load_runtime_cache(runtime_cache_path, runtime_cache.get()); + } catch (const std::exception& e) { + LOG_WARNING("Failed to load runtime cache from " << runtime_cache_path << ": " << e.what()); + } + if (config->setRuntimeCache(*runtime_cache)) { + LOG_DEBUG("TensorRT-RTX runtime cache configured at " << runtime_cache_path); + } else { + LOG_WARNING("Failed to attach runtime cache to IRuntimeConfig; cache will be unused."); + runtime_cache.reset(); + } + } + } else { + LOG_DEBUG("Runtime cache disabled (no path configured)."); + } + + // Dynamic shapes kernel specialization strategy -- TRT-RTX only. + config->setDynamicShapesKernelSpecializationStrategy( + static_cast(dynamic_shapes_kernel_strategy)); + LOG_DEBUG("Dynamic shapes kernel specialization strategy set to " << to_string(dynamic_shapes_kernel_strategy)); + + // CUDA graph strategy -- TRT-RTX only. + if (!config->setCudaGraphStrategy( + cuda_graph_strategy == CudaGraphStrategyOption::kWholeGraphCapture + ? nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE + : nvinfer1::CudaGraphStrategy::kDISABLED)) { + LOG_WARNING("Failed to set CUDA graph strategy; continuing with default."); + } +#endif +} + +void TRTRuntimeConfig::set_execution_context_allocation_strategy( + nvinfer1::ExecutionContextAllocationStrategy strategy) const { + TORCHTRT_ASSERT(config, "TRTRuntimeConfig::config must be initialized before setting allocation strategy"); + config->setExecutionContextAllocationStrategy(strategy); +} + +bool TRTRuntimeConfig::uses_internal_capture(TORCHTRT_UNUSED bool cudagraphs_enabled) const { +#ifdef TRT_MAJOR_RTX + // On TRT-RTX the internal runtime handles capture/replay whenever a non-disabled + // strategy is set, or when subgraph cudagraphs are enabled globally. 
In both cases the + // caller should skip its manual at::cuda::CUDAGraph wrapper because TRT-RTX's internal + // capture would collide with it. + return cuda_graph_strategy != CudaGraphStrategyOption::kDisabled || cudagraphs_enabled; +#else + return false; +#endif +} + +void TRTRuntimeConfig::disable_rtx_native_cudagraphs(TORCHTRT_UNUSED const std::string& engine_name) noexcept { +#ifdef TRT_MAJOR_RTX + if (rtx_native_cudagraphs_disabled || cuda_graph_strategy == CudaGraphStrategyOption::kDisabled) { + return; + } + LOG_WARNING( + "Outer CUDA stream capture detected; disabling TensorRT-RTX native CUDA graph strategy on engine " + << engine_name << " for the remainder of its lifetime."); + // Persist any kernels the engine-internal capture has compiled so far; the outer + // capture will run without them otherwise, and we want future reloads to reuse them. + save_runtime_cache(); + cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; + if (config && !config->setCudaGraphStrategy(nvinfer1::CudaGraphStrategy::kDISABLED)) { + LOG_WARNING("Failed to update CUDA graph strategy on IRuntimeConfig after disable."); + } + rtx_native_cudagraphs_disabled = true; +#endif +} + +bool TRTRuntimeConfig::is_monolithic_capturable( + TORCHTRT_UNUSED nvinfer1::IExecutionContext* exec_ctx, + TORCHTRT_UNUSED cudaStream_t stream) const { +#ifdef TRT_MAJOR_RTX + TORCHTRT_ASSERT(exec_ctx != nullptr, "is_monolithic_capturable requires a live IExecutionContext"); + // "lazy" kernel specialization swaps specialized kernels in mid-run, which invalidates + // captured graphs. Other strategies (eager/none) are safe when the context reports the + // stream capturable. 
+ return exec_ctx->isStreamCapturable(stream) && dynamic_shapes_kernel_strategy != DynamicShapesKernelStrategy::kLazy; +#else + return true; +#endif +} + +void TRTRuntimeConfig::save_runtime_cache() noexcept { +#ifdef TRT_MAJOR_RTX + if (!runtime_cache || runtime_cache_path.empty()) { + return; + } + try { + save_runtime_cache_impl(runtime_cache_path, runtime_cache.get()); + } catch (const std::exception& e) { + LOG_WARNING("Failed to save runtime cache to " << runtime_cache_path << ": " << e.what()); + } catch (...) { + LOG_WARNING("Failed to save runtime cache (unknown exception)."); + } +#endif +} + +std::string TRTRuntimeConfig::to_str() const { + std::ostringstream os; + os << "Runtime Cache Path: " << (runtime_cache_path.empty() ? "" : runtime_cache_path) << std::endl; + os << "Dynamic Shapes Kernel Strategy: " << to_string(dynamic_shapes_kernel_strategy) << std::endl; + os << "CUDA Graph Strategy: " << to_string(cuda_graph_strategy) << std::endl; + return os.str(); +} + +TRTRuntimeConfig make_runtime_config_from_serialized(TORCHTRT_UNUSED const std::vector& info) { + TRTRuntimeConfig cfg; +#ifdef TRT_MAJOR_RTX + cfg.runtime_cache_path = info[RUNTIME_CACHE_PATH_IDX]; + cfg.dynamic_shapes_kernel_strategy = + to_dynamic_shapes_kernel_strategy(std::stoi(info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX])); + cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(std::stoi(info[CUDA_GRAPH_STRATEGY_IDX])); +#endif + return cfg; +} + +std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg) { + os << "Runtime cfg {" << std::endl; + os << cfg.to_str(); + os << "}" << std::endl; + return os; +} + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/TRTRuntimeConfig.h b/core/runtime/TRTRuntimeConfig.h new file mode 100644 index 0000000000..e964706c2e --- /dev/null +++ b/core/runtime/TRTRuntimeConfig.h @@ -0,0 +1,95 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include 
"NvInfer.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +// TensorRT-RTX-only configuration for how shape-specialized kernels are compiled. +enum class DynamicShapesKernelStrategy : int32_t { + kLazy = 0, + kEager = 1, + kNone = 2, +}; + +// TensorRT-RTX-only configuration for how CUDA graph capture/replay is handled. +enum class CudaGraphStrategyOption : int32_t { + kDisabled = 0, + kWholeGraphCapture = 1, +}; + +// Encapsulates the nvinfer1::IRuntimeConfig owned by a TRTEngine along with the +// TensorRT-RTX-specific state (runtime cache, dynamic shapes kernel strategy, native +// CUDA graph strategy). All `#ifdef TRT_MAJOR_RTX` guards live in this file and its +// implementation so callers can treat this struct uniformly between RTX and standard +// TensorRT builds. +struct TRTRuntimeConfig { + // Settings - typically populated from engine deserialization before `ensure_initialized`. + std::string runtime_cache_path = ""; + DynamicShapesKernelStrategy dynamic_shapes_kernel_strategy = DynamicShapesKernelStrategy::kLazy; + CudaGraphStrategyOption cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; + + // One-shot: set to true once an outer stream capture has been detected and the + // engine-internal CUDA graph strategy has been disabled for the remainder of the + // owning engine's lifetime. + bool rtx_native_cudagraphs_disabled = false; + + // Live resources. The IRuntimeConfig is lazy-constructed on first `ensure_initialized`. + std::shared_ptr config; +#ifdef TRT_MAJOR_RTX + std::shared_ptr runtime_cache; +#endif + + // Construct the IRuntimeConfig once and apply all TRT-RTX-specific settings. Safe to + // call multiple times; only the first call initializes and applies the RTX-only + // setters. On subsequent calls this is a no-op. + void ensure_initialized(nvinfer1::ICudaEngine* cuda_engine); + + // Apply (or re-apply) the execution context allocation strategy on the IRuntimeConfig. 
+ // Available on both standard TensorRT and TensorRT-RTX via IRuntimeConfig. + void set_execution_context_allocation_strategy(nvinfer1::ExecutionContextAllocationStrategy strategy) const; + + // Returns true if the TensorRT-RTX runtime owns capture/replay for this engine so the + // caller should bypass its own at::cuda::CUDAGraph capture around enqueueV3. Always + // false on non-RTX builds. + [[nodiscard]] bool uses_internal_capture(bool cudagraphs_enabled) const; + + // One-shot: disable engine-internal CUDA graph capture. Invoked when an outer stream + // capture is detected around execute_engine, so the outer capture can contain the + // kernel launches directly. Saves the runtime cache before recreating the context so + // compiled kernels from the present run are preserved for future reloads. + void disable_rtx_native_cudagraphs(const std::string& engine_name) noexcept; + + // Whether the execution context is safe to include in an outer monolithic capture. + // Non-RTX builds always return true. + [[nodiscard]] bool is_monolithic_capturable(nvinfer1::IExecutionContext* exec_ctx, cudaStream_t stream) const; + + // Save the runtime cache to disk. Signature is `noexcept` so this is safe from a + // destructor. The underlying file I/O is performed by file-local helpers in + // TRTRuntimeConfig.cpp that may throw; this member wraps them, logs a warning, and + // swallows any exceptions. + void save_runtime_cache() noexcept; + + // Returns a human-readable summary of the runtime config. + [[nodiscard]] std::string to_str() const; +}; + +// Construct a TRTRuntimeConfig from a flattened serialization vector. Reads the +// RTX-only indices only on RTX builds; standard TRT builds return a default-initialized +// struct. 
+[[nodiscard]] TRTRuntimeConfig make_runtime_config_from_serialized(const std::vector& info); + +std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg); + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 6a070db3cf..80936951ef 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -241,6 +241,23 @@ std::vector execute_engine(std::vector inputs, c10::intr auto run_standard_execution = [&]() { bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); + // effective_cudagraphs controls the manual at::cuda::CUDAGraph path below. On TRT-RTX + // builds the engine-internal runtime owns capture/replay inside enqueueV3 whenever the + // engine has a cuda_graph_strategy set or subgraph cudagraphs are enabled; the struct + // reports that via `uses_internal_capture` so the caller skips its manual wrapper. If + // an outer stream capture is already in progress (e.g. the caller wraps this module in + // CudaGraphsTorchTensorRTModule for whole-graph capture), engine-internal capture would + // collide, so we disable it one-shot here. + bool effective_cudagraphs = cudagraphs_enabled; + if (compiled_engine->runtime_cfg.uses_internal_capture(cudagraphs_enabled)) { + effective_cudagraphs = false; + cudaStreamCaptureStatus capture_status; + cudaStreamIsCapturing(compiled_engine->engine_stream.stream(), &capture_status); + if (capture_status != cudaStreamCaptureStatusNone) { + compiled_engine->disable_rtx_native_cudagraphs(); + } + } + bool shape_changed = _validate_shapes(inputs, compiled_engine); auto current_device_id = inputs.size() > 0 ? 
inputs[0].device().index() : at::cuda::current_device(); @@ -262,7 +279,7 @@ std::vector execute_engine(std::vector inputs, c10::intr // Whether cudagraphs needs to record the graph on this pass auto result = compiled_engine->runtime_states.set_runtime_states( - cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed); + effective_cudagraphs, compiled_engine->use_pre_allocated_outputs, shape_changed); bool need_cudagraphs_record = std::get<0>(result); bool can_use_pre_allocated_outputs = std::get<1>(result); @@ -282,7 +299,7 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->input_profile_path); } - setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record); + setup_input_tensors(inputs, compiled_engine, effective_cudagraphs, need_cudagraphs_record); // Check if input shapes can be inferred. int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; std::vector names(io_size); @@ -314,7 +331,7 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->cudagraph_output_staging_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); } - if (cudagraphs_enabled) { + if (effective_cudagraphs) { TORCHTRT_CHECK( compiled_engine->exec_ctx->setTensorAddress( name.c_str(), compiled_engine->cudagraph_output_staging_buffers[pyt_idx].data_ptr()), @@ -346,8 +363,10 @@ std::vector execute_engine(std::vector inputs, c10::intr caller_exec_complete.block(compiled_engine->engine_stream); } - if (!cudagraphs_enabled) { - // Direct execution uses the caller buffers directly + if (!effective_cudagraphs) { + // Direct execution uses the caller buffers directly. On TRT-RTX with a + // cuda_graph_strategy set, the engine captures/replays internally during + // this enqueueV3 call. 
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); } else { if (need_cudagraphs_record) { @@ -384,7 +403,7 @@ std::vector execute_engine(std::vector inputs, c10::intr trt_exec_complete.block(compiled_engine->caller_stream); } - if (cudagraphs_enabled) { + if (effective_cudagraphs) { // If in CUDAGraph mode, copy persistent staging outputs to returned tensors on the caller stream. for (size_t o = 0; o < compiled_engine->cudagraph_output_staging_buffers.size(); o++) { outputs[o].copy_(compiled_engine->cudagraph_output_staging_buffers[o], false); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 7eae8bfb91..749f7c7f81 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -147,6 +147,11 @@ TORCH_LIBRARY(tensorrt, m) { return false; #endif }); +#ifdef TRT_MAJOR_RTX + m.def("RUNTIME_CACHE_PATH_IDX", []() -> int64_t { return RUNTIME_CACHE_PATH_IDX; }); + m.def("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", []() -> int64_t { return DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX; }); + m.def("CUDA_GRAPH_STRATEGY_IDX", []() -> int64_t { return CUDA_GRAPH_STRATEGY_IDX; }); +#endif m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index a87bd2ca2a..25d9cd6dd2 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -41,6 +41,11 @@ typedef enum { REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, REQUIRES_NATIVE_MULTIDEVICE_IDX, +#ifdef TRT_MAJOR_RTX + RUNTIME_CACHE_PATH_IDX, + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, + CUDA_GRAPH_STRATEGY_IDX, +#endif SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 0386c97ea3..da093a519f 
100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -14,7 +14,9 @@ from torch_tensorrt.dynamo.runtime._serialized_engine_layout import ( ABI_TARGET_IDX, ABI_VERSION, + CUDA_GRAPH_STRATEGY_IDX, DEVICE_IDX, + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, ENGINE_IDX, HW_COMPATIBLE_IDX, INPUT_BINDING_NAMES_IDX, @@ -23,6 +25,7 @@ REQUIRES_NATIVE_MULTIDEVICE_IDX, REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, + RUNTIME_CACHE_PATH_IDX, SERIALIZATION_LEN, SERIALIZED_METADATA_IDX, TARGET_PLATFORM_IDX, @@ -40,6 +43,16 @@ List[str], ] +_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP: Dict[str, int] = { + "lazy": 0, + "eager": 1, + "none": 2, +} +_CUDA_GRAPH_STRATEGY_MAP: Dict[str, int] = { + "disabled": 0, + "whole_graph_capture": 1, +} + class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc] """``nn.Module`` that runs a TensorRT engine inside PyTorch. @@ -132,6 +145,28 @@ def __init__( self.execute_engine_op: Any = None self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources + # TensorRT-RTX-only runtime config mirror. The engine-info serialization slots + # only exist on RTX builds (see below), but we validate the strategy names on + # every build so typos are caught regardless of backend. 
+ self.runtime_cache_path = settings.runtime_cache_path + self.dynamic_shapes_kernel_specialization_strategy = ( + settings.dynamic_shapes_kernel_specialization_strategy + ) + if ( + self.dynamic_shapes_kernel_specialization_strategy + not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP + ): + raise ValueError( + f"Invalid dynamic_shapes_kernel_specialization_strategy " + f"{self.dynamic_shapes_kernel_specialization_strategy!r}; expected one of " + f"{list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP.keys())}" + ) + self.cuda_graph_strategy = settings.cuda_graph_strategy + if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: + raise ValueError( + f"Invalid cuda_graph_strategy {self.cuda_graph_strategy!r}; expected one of " + f"{list(_CUDA_GRAPH_STRATEGY_MAP.keys())}" + ) self.symbolic_shape_expressions = symbolic_shape_expressions self.requires_native_multidevice = requires_native_multidevice self.target_platform = ( @@ -229,6 +264,18 @@ def _pack_engine_info(self) -> List[str | bytes]: int(self.requires_native_multidevice) ) # rank/world_size are runtime facts; queried from ProcessGroup at execution time + # Strategy names were validated at __init__ on every build; the index slots + # themselves only exist on RTX. 
+ if ENABLED_FEATURES.tensorrt_rtx: + engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" + engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( + _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP[ + self.dynamic_shapes_kernel_specialization_strategy + ] + ) + engine_info[CUDA_GRAPH_STRATEGY_IDX] = str( + _CUDA_GRAPH_STRATEGY_MAP[self.cuda_graph_strategy] + ) return engine_info @@ -332,8 +379,9 @@ def decode_metadata(encoded_metadata: bytes) -> Any: def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt: if self.engine: engine_info = self._pack_engine_info() - assert isinstance(engine_info[ENGINE_IDX], (bytes, bytearray)) - engine_info[ENGINE_IDX] = base64.b64encode(engine_info[ENGINE_IDX]) + engine_bytes = engine_info[ENGINE_IDX] + assert isinstance(engine_bytes, (bytes, bytearray)) + engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes) return ( self.name, engine_info, @@ -342,8 +390,9 @@ def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt: ) elif self.serialized_engine: engine_info = self._pack_engine_info() - assert isinstance(engine_info[ENGINE_IDX], bytes) - engine_info[ENGINE_IDX] = base64.b64encode(engine_info[ENGINE_IDX]) + engine_bytes = engine_info[ENGINE_IDX] + assert isinstance(engine_bytes, bytes) + engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes) return ( self.name, engine_info, diff --git a/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py b/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py index d4f31ba8a8..ef8938dc8d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py +++ b/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py @@ -52,7 +52,18 @@ class SerializedInfoIndex(IntEnum): REQUIRES_OUTPUT_ALLOCATOR_IDX = SerializedInfoIndex.REQUIRES_OUTPUT_ALLOCATOR_IDX RESOURCE_ALLOCATION_STRATEGY_IDX = SerializedInfoIndex.RESOURCE_ALLOCATION_STRATEGY_IDX REQUIRES_NATIVE_MULTIDEVICE_IDX = SerializedInfoIndex.REQUIRES_NATIVE_MULTIDEVICE_IDX + +# TensorRT-RTX-only 
indices. The C++ side gates these on ``#ifdef TRT_MAJOR_RTX``, +# so they only exist when the loaded runtime is the RTX build. +RUNTIME_CACHE_PATH_IDX = -1 +DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = -1 +CUDA_GRAPH_STRATEGY_IDX = -1 SERIALIZATION_LEN = len(SerializedInfoIndex) +if ENABLED_FEATURES.tensorrt_rtx: + RUNTIME_CACHE_PATH_IDX = len(SerializedInfoIndex) + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = len(SerializedInfoIndex) + 1 + CUDA_GRAPH_STRATEGY_IDX = len(SerializedInfoIndex) + 2 + SERIALIZATION_LEN = len(SerializedInfoIndex) + 3 SERIALIZED_ENGINE_BINDING_DELIM = "%" SERIALIZED_RT_DEVICE_DELIM = "%" @@ -78,12 +89,22 @@ class SerializedInfoIndex(IntEnum): ("SERIALIZED_RT_DEVICE_DELIM", "SERIALIZED_RT_DEVICE_DELIM", str), ) +# TensorRT-RTX-only checks. The C++ ops are only registered on RTX builds. +_LAYOUT_CPP_CHECKS_RTX: tuple[_LayoutCheck, ...] = ( + ("RUNTIME_CACHE_PATH_IDX", "RUNTIME_CACHE_PATH_IDX", int), + ("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", "DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", int), + ("CUDA_GRAPH_STRATEGY_IDX", "CUDA_GRAPH_STRATEGY_IDX", int), +) + def _assert_serialized_layout_matches_cpp() -> None: """Fail fast if Python layout literals diverge from ``register_jit_hooks.cpp``.""" if not ENABLED_FEATURES.torch_tensorrt_runtime: return - for op_name, global_name, normalizer in _LAYOUT_CPP_CHECKS: + checks = _LAYOUT_CPP_CHECKS + if ENABLED_FEATURES.tensorrt_rtx: + checks = checks + _LAYOUT_CPP_CHECKS_RTX + for op_name, global_name, normalizer in checks: expected = globals()[global_name] try: op = getattr(torch.ops.tensorrt, op_name) From 333bee343b5b7e692802e300b03b39e8230b82d6 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 6 May 2026 03:24:16 -0700 Subject: [PATCH 2/9] fix(runtime): route post-NCCL-release exec context through recreate_execution_context release_nccl_comm() previously rebuilt the IExecutionContext via direct calls to ICudaEngine::createExecutionContext, bypassing the TRTRuntimeConfig plumbing introduced earlier in this PR. 
On that path the RTX runtime cache was not flushed before context teardown, and the dynamic shapes kernel specialization and CUDA graph strategies stored on TRTRuntimeConfig were not re-applied to the new context. Delegate to recreate_execution_context() instead. It saves the runtime cache, ensures TRTRuntimeConfig is initialized, sets the allocation strategy from resource_allocation_strategy, and creates the new exec context via createExecutionContext(runtime_cfg.config.get()), keeping all strategies live across the NCCL bind/release cycle. --- core/runtime/TRTEngine.cpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index fe928f5cc4..7aa38d3d0d 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -640,14 +640,7 @@ void TRTEngine::release_nccl_comm() { LOG_INFO("Releasing NCCL communicator from engine '" << this->name << "'"); torch::cuda::synchronize(device_info.id); this->exec_ctx.reset(); - if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { - this->exec_ctx = - make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); - } - TORCHTRT_CHECK( - (exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context after releasing NCCL comm"); + recreate_execution_context(); this->nccl_initialized = false; LOG_INFO("NCCL communicator released from engine '" << this->name << "'"); } From 4e5243012f7b33f9d3ca2e5cdd9c01213f694a8f Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Fri, 8 May 2026 17:12:03 -0700 Subject: [PATCH 3/9] fix(runtime): validate strategy strings on every build, not only RTX cuda_graph_strategy and dynamic_shapes_kernel_specialization_strategy are TRT-RTX-only at runtime, but they are accepted on every build through the public compile() / CompilationSettings surface. 
Their string-to-enum lookup lived inside the 'if ENABLED_FEATURES.tensorrt_rtx:' block in _pack_engine_info(), so on a standard (non-RTX) build a typo like cuda_graph_strategy="wholee_graph_capture" was silently dropped instead of raising. Hoist the membership check into TorchTensorRTModule.__init__ so that invalid strategy names always raise ValueError, regardless of backend. The RTX-gated index population in _pack_engine_info() keeps reading the maps unchanged -- only the redundant validation moves. Fixes the L1 dynamo core tests on standard-TensorRT Windows: TestCudaGraphStrategyInvalidValue::test_invalid_strategy_raises TestDynamicShapesKernelStrategyCppInvalidValue::test_invalid_strategy_raises --- py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index da093a519f..c21f4e6496 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -264,8 +264,8 @@ def _pack_engine_info(self) -> List[str | bytes]: int(self.requires_native_multidevice) ) # rank/world_size are runtime facts; queried from ProcessGroup at execution time - # Strategy names were validated at __init__ on every build; the index slots - # themselves only exist on RTX. + # Strategy names are validated at __init__ time so typos fail fast on every + # build; the index slots themselves only exist on RTX. 
if ENABLED_FEATURES.tensorrt_rtx: engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( From c030081b5098a3c669b0589b0be515368163125d Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Fri, 8 May 2026 23:32:13 -0700 Subject: [PATCH 4/9] fix(runtime): gate IRuntimeConfig usage for older TRT (Jetpack) The C++ runtime config introduced in this branch unconditionally referenced nvinfer1::IRuntimeConfig, which is only available on TensorRT-RTX and on standard TensorRT >= 10.11. The TensorRT shipped with the Jetpack l4t-r36.4 toolchain (@tensorrt_l4t) predates 10.11 and does not export this type, so the aarch64-jetpack build fails: ./core/runtime/TRTRuntimeConfig.h:47:29: error: 'IRuntimeConfig' is not a member of 'nvinfer1' Inject a TRT_HAS_IRUNTIME_CONFIG macro from core/runtime/BUILD via a 'defines = select({...})' on //core/runtime:runtime. The macro is set on every build configuration except :jetpack (RTX, SBSA, Windows, default x86_64 Linux). This is symmetric with how TRT_MAJOR_RTX and ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION are already injected per-config in the same target. In the C++ sources, gate the IRuntimeConfig-using state with '#ifdef TRT_HAS_IRUNTIME_CONFIG' inside TRTRuntimeConfig.{h,cpp}, and expose a single TRTRuntimeConfig::create_execution_context member that selects the right createExecutionContext overload internally: - IRuntimeConfig path (>= 10.11 / RTX): set the allocation strategy on the IRuntimeConfig and call createExecutionContext(IRuntimeConfig*). - Legacy path (older TRT, e.g. Jetpack): call the legacy createExecutionContext(ExecutionContextAllocationStrategy) overload directly. The Jetpack path therefore still respects the user-requested kDynamic / kSTATIC choice. Callers in TRTEngine.cpp invoke runtime_cfg.create_execution_context(...) and stay free of any TRT_HAS_IRUNTIME_CONFIG branching. 
The previous public TRTRuntimeConfig::set_execution_context_allocation_strategy method had only one caller and is removed. The pre-existing TRT_MAJOR_RTX-gated runtime_cache / dynamic-shapes / cuda-graph blocks remain a strict subset of TRT_HAS_IRUNTIME_CONFIG, so behavior on TRT-RTX and on modern standard TensorRT is unchanged. Note: macro semantics are now 'is the build config named jetpack?' rather than 'does TRT actually export IRuntimeConfig?'. If @tensorrt_l4t ever bumps to 10.11+, the BUILD select needs to be updated to flip the gate on for jetpack. --- core/runtime/BUILD | 14 ++++++++++++++ core/runtime/TRTEngine.cpp | 10 ++++------ core/runtime/TRTEngine.h | 9 ++++----- core/runtime/TRTRuntimeConfig.cpp | 19 ++++++++++++++----- core/runtime/TRTRuntimeConfig.h | 28 ++++++++++++++++------------ 5 files changed, 52 insertions(+), 28 deletions(-) diff --git a/core/runtime/BUILD b/core/runtime/BUILD index 48d6441352..bb9f779929 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -103,6 +103,20 @@ cc_library( "runtime.h", ], copts = if_torch_nccl(["-DUSE_C10D_NCCL"]), + defines = select({ + # nvinfer1::IRuntimeConfig (and the matching ICudaEngine::createRuntimeConfig + # / createExecutionContext(IRuntimeConfig*) overloads) was introduced in + # TensorRT 10.11. The TensorRT shipped with the Jetpack l4t-r36.4 toolchain + # (@tensorrt_l4t) predates 10.11 and does not export this type. Every other + # configuration here (RTX, SBSA, Windows, default x86_64 Linux) is on a + # TensorRT >= 10.11 bundle, so it gets the macro. + # + # Gate every IRuntimeConfig-using site in core/runtime with + # `#ifdef TRT_HAS_IRUNTIME_CONFIG`; the Jetpack path falls back to the + # legacy createExecutionContext(ExecutionContextAllocationStrategy) overload. 
+ ":jetpack": [], + "//conditions:default": ["TRT_HAS_IRUNTIME_CONFIG"], + }), linkopts = [ "-lstdc++fs", ], diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 7aa38d3d0d..3a16354376 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -667,12 +667,10 @@ void TRTEngine::recreate_execution_context() { // allocator changes, or process kills that happen between allocator changes and // teardown. No-op on standard TensorRT or when no cache path is configured. runtime_cfg.save_runtime_cache(); - runtime_cfg.ensure_initialized(cuda_engine.get()); - runtime_cfg.set_execution_context_allocation_strategy( - resource_allocation_strategy == ResourceAllocationStrategy::kDynamic - ? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED - : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC); - exec_ctx = make_trt(cuda_engine->createExecutionContext(runtime_cfg.config.get())); + const auto allocation_strategy = resource_allocation_strategy == ResourceAllocationStrategy::kDynamic + ? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED + : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC; + exec_ctx = runtime_cfg.create_execution_context(cuda_engine.get(), allocation_strategy); TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 467d917baf..6347357f7c 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -284,9 +284,9 @@ struct TRTEngine : torch::CustomClassHolder { void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy); ResourceAllocationStrategy get_resource_allocation_strategy(); - // All TensorRT-RTX-specific IRuntimeConfig state lives here. On non-RTX builds this - // still owns a shared IRuntimeConfig (so the execution-context allocation strategy is - // applied via the uniform code path) but the RTX-only setters become no-ops. 
+ // Owns the IRuntimeConfig (where supported) and TRT-RTX runtime state. On older TRT + // without IRuntimeConfig (e.g. Jetpack) this just carries strategy values that get + // passed to the legacy createExecutionContext overload. TRTRuntimeConfig runtime_cfg; // Monolithic-capturability check used when this engine is wrapped by an outer whole-graph @@ -298,8 +298,7 @@ struct TRTEngine : torch::CustomClassHolder { void disable_rtx_native_cudagraphs(); private: - // Single entry point that (re)creates exec_ctx. Also creates (once) the IRuntimeConfig - // owned by runtime_cfg and applies all runtime config settings. + // Single entry point that (re)creates exec_ctx via runtime_cfg.create_execution_context. void recreate_execution_context(); }; diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp index 0804a0a7fa..c0f6e8c37e 100644 --- a/core/runtime/TRTRuntimeConfig.cpp +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -102,7 +102,8 @@ void save_runtime_cache_impl(const std::string& path, nvinfer1::IRuntimeCache* c } // namespace -void TRTRuntimeConfig::ensure_initialized(nvinfer1::ICudaEngine* cuda_engine) { +void TRTRuntimeConfig::ensure_initialized(TORCHTRT_UNUSED nvinfer1::ICudaEngine* cuda_engine) { +#ifdef TRT_HAS_IRUNTIME_CONFIG if (config) { return; } @@ -146,12 +147,20 @@ void TRTRuntimeConfig::ensure_initialized(nvinfer1::ICudaEngine* cuda_engine) { LOG_WARNING("Failed to set CUDA graph strategy; continuing with default."); } #endif +#endif // TRT_HAS_IRUNTIME_CONFIG } -void TRTRuntimeConfig::set_execution_context_allocation_strategy( - nvinfer1::ExecutionContextAllocationStrategy strategy) const { - TORCHTRT_ASSERT(config, "TRTRuntimeConfig::config must be initialized before setting allocation strategy"); - config->setExecutionContextAllocationStrategy(strategy); +std::shared_ptr TRTRuntimeConfig::create_execution_context( + nvinfer1::ICudaEngine* cuda_engine, + nvinfer1::ExecutionContextAllocationStrategy allocation_strategy) { + 
ensure_initialized(cuda_engine); +#ifdef TRT_HAS_IRUNTIME_CONFIG + config->setExecutionContextAllocationStrategy(allocation_strategy); + return make_trt(cuda_engine->createExecutionContext(config.get())); +#else + // Pre-10.11 TRT (e.g. Jetpack): use the legacy strategy overload directly. + return make_trt(cuda_engine->createExecutionContext(allocation_strategy)); +#endif } bool TRTRuntimeConfig::uses_internal_capture(TORCHTRT_UNUSED bool cudagraphs_enabled) const { diff --git a/core/runtime/TRTRuntimeConfig.h b/core/runtime/TRTRuntimeConfig.h index e964706c2e..489d59fcd0 100644 --- a/core/runtime/TRTRuntimeConfig.h +++ b/core/runtime/TRTRuntimeConfig.h @@ -27,11 +27,9 @@ enum class CudaGraphStrategyOption : int32_t { kWholeGraphCapture = 1, }; -// Encapsulates the nvinfer1::IRuntimeConfig owned by a TRTEngine along with the -// TensorRT-RTX-specific state (runtime cache, dynamic shapes kernel strategy, native -// CUDA graph strategy). All `#ifdef TRT_MAJOR_RTX` guards live in this file and its -// implementation so callers can treat this struct uniformly between RTX and standard -// TensorRT builds. +// Encapsulates the IRuntimeConfig and TRT-RTX runtime state for a TRTEngine. +// IRuntimeConfig and runtime-cache `#ifdef`s are confined to this TU; serialization- +// index plumbing keeps its own RTX gates elsewhere. struct TRTRuntimeConfig { // Settings - typically populated from engine deserialization before `ensure_initialized`. std::string runtime_cache_path = ""; @@ -43,20 +41,26 @@ struct TRTRuntimeConfig { // owning engine's lifetime. bool rtx_native_cudagraphs_disabled = false; - // Live resources. The IRuntimeConfig is lazy-constructed on first `ensure_initialized`. + // Live resources. The IRuntimeConfig is lazy-constructed on first `ensure_initialized` + // and is unavailable on TensorRT versions older than 10.11 (e.g. Jetpack). 
+#ifdef TRT_HAS_IRUNTIME_CONFIG std::shared_ptr config; +#endif #ifdef TRT_MAJOR_RTX std::shared_ptr runtime_cache; #endif - // Construct the IRuntimeConfig once and apply all TRT-RTX-specific settings. Safe to - // call multiple times; only the first call initializes and applies the RTX-only - // setters. On subsequent calls this is a no-op. + // Lazily construct the IRuntimeConfig and apply RTX-specific settings. Idempotent. + // No-op on builds without IRuntimeConfig (e.g. Jetpack). void ensure_initialized(nvinfer1::ICudaEngine* cuda_engine); - // Apply (or re-apply) the execution context allocation strategy on the IRuntimeConfig. - // Available on both standard TensorRT and TensorRT-RTX via IRuntimeConfig. - void set_execution_context_allocation_strategy(nvinfer1::ExecutionContextAllocationStrategy strategy) const; + // Lazy-initialize the IRuntimeConfig if needed and create an IExecutionContext that + // honors `allocation_strategy`. Selects the right `createExecutionContext` overload + // (IRuntimeConfig* vs ExecutionContextAllocationStrategy) so callers stay free of + // any TRT_HAS_IRUNTIME_CONFIG branching. + [[nodiscard]] std::shared_ptr create_execution_context( + nvinfer1::ICudaEngine* cuda_engine, + nvinfer1::ExecutionContextAllocationStrategy allocation_strategy); // Returns true if the TensorRT-RTX runtime owns capture/replay for this engine so the // caller should bypass its own at::cuda::CUDAGraph capture around enqueueV3. Always From 976c8f761bf076139167861eec3c52002ce62b06 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Thu, 28 May 2026 14:13:51 -0700 Subject: [PATCH 5/9] test(runtime): smoke tests for TRT-RTX features on C++ runtime Add tests/py/dynamo/runtime/test_001_cpp_runtime_rtx_features.py covering the three TRT-RTX features (runtime cache, dynamic shapes kernel strategy, native CUDA graph strategy) when use_python_runtime=False. 
The Python-runtime tests assert on Python TRTEngine attributes that the C++ engine (torch.classes.tensorrt.Engine) does not expose, so the C++ tests instead verify externally observable behavior: strategy-name typo validation in TorchTensorRTModule.__init__, compile+infer correctness via cosine similarity, and runtime-cache file persistence on destruction. --- .../test_001_cpp_runtime_rtx_features.py | 174 ++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 tests/py/dynamo/runtime/test_001_cpp_runtime_rtx_features.py diff --git a/tests/py/dynamo/runtime/test_001_cpp_runtime_rtx_features.py b/tests/py/dynamo/runtime/test_001_cpp_runtime_rtx_features.py new file mode 100644 index 0000000000..f92d8834be --- /dev/null +++ b/tests/py/dynamo/runtime/test_001_cpp_runtime_rtx_features.py @@ -0,0 +1,174 @@ +"""C++ runtime smoke tests for the three TensorRT-RTX runtime features. + +These tests verify the C++ runtime path (``use_python_runtime=False``) wires up +the runtime cache, dynamic-shapes kernel specialization strategy, and native +CUDA graph strategy via the serialized engine info indices. The Python runtime +equivalents live in ``test_000_runtime_cache.py``, +``test_001_dynamic_shapes_kernel_strategy.py``, and +``test_001_cuda_graph_strategy.py`` and assert via Python attributes on +:class:`TRTEngine`; the C++ ``torch.classes.tensorrt.Engine`` does not expose +those attributes to Python, so this file asserts on externally observable +behavior (compilation succeeds, inference returns correct outputs, cache files +appear on disk). 
+""" + +from __future__ import annotations + +import os +import shutil +import tempfile +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + + +class SimpleModel(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + 1.0 + + +def _compile_cpp(**extra_kwargs): + """Compile :class:`SimpleModel` against the C++ runtime.""" + model = SimpleModel().eval().cuda() + inputs = [torch.randn(2, 3).cuda()] + kwargs = { + "ir": "dynamo", + "inputs": inputs, + "use_python_runtime": False, + "min_block_size": 1, + } + kwargs.update(extra_kwargs) + compiled = torchtrt.compile(model, **kwargs) + torch._dynamo.reset() + return compiled, inputs, model + + +def _assert_cpp_runtime_used(testcase: TestCase, compiled) -> None: + """Walk the compiled module and assert at least one C++ engine is present.""" + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule + from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine + + found_cpp = False + for _, mod in compiled.named_modules(): + if isinstance(mod, TorchTensorRTModule): + testcase.assertFalse( + isinstance(mod.engine, TRTEngine), + "C++ runtime expected but found Python TRTEngine", + ) + found_cpp = True + testcase.assertTrue(found_cpp, "No TorchTensorRTModule found in compiled graph") + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "RTX-only features require TensorRT-RTX", +) +class TestCppRuntimeStrategyValidation(TestCase): + """Strategy-name typos must be rejected before engine construction.""" + + def test_invalid_dynamic_shapes_strategy_rejected(self): + with self.assertRaises(ValueError): + 
_compile_cpp(dynamic_shapes_kernel_specialization_strategy="invalid") + + def test_invalid_cuda_graph_strategy_rejected(self): + with self.assertRaises(ValueError): + _compile_cpp(cuda_graph_strategy="invalid_strategy") + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "RTX-only features require TensorRT-RTX", +) +class TestCppRuntimeSmoke(TestCase): + """End-to-end compile + infer + correctness on the C++ runtime.""" + + def _run_and_check(self, compiled, inputs, model): + ref = model(*inputs) + out = compiled(*[inp.clone() for inp in inputs]) + sim = cosine_similarity(ref, out) + self.assertGreaterEqual( + sim, + COSINE_THRESHOLD, + f"C++ runtime output diverged from reference (cosine={sim})", + ) + + def test_default_settings(self): + compiled, inputs, model = _compile_cpp() + _assert_cpp_runtime_used(self, compiled) + self._run_and_check(compiled, inputs, model) + + def test_eager_kernel_strategy(self): + compiled, inputs, model = _compile_cpp( + dynamic_shapes_kernel_specialization_strategy="eager" + ) + _assert_cpp_runtime_used(self, compiled) + self._run_and_check(compiled, inputs, model) + + def test_none_kernel_strategy(self): + compiled, inputs, model = _compile_cpp( + dynamic_shapes_kernel_specialization_strategy="none" + ) + _assert_cpp_runtime_used(self, compiled) + self._run_and_check(compiled, inputs, model) + + def test_whole_graph_capture(self): + compiled, inputs, model = _compile_cpp( + cuda_graph_strategy="whole_graph_capture" + ) + _assert_cpp_runtime_used(self, compiled) + self._run_and_check(compiled, inputs, model) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "RTX-only features require TensorRT-RTX", +) +class TestCppRuntimeCachePersistence(TestCase): + """Verify the C++ runtime writes a cache file on engine 
destruction.""" + + def setUp(self): + self.cache_dir = tempfile.mkdtemp() + self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") + + def tearDown(self): + shutil.rmtree(self.cache_dir, ignore_errors=True) + + def test_cache_file_written_on_destruction(self): + import gc + + compiled, inputs, model = _compile_cpp(runtime_cache_path=self.cache_path) + _assert_cpp_runtime_used(self, compiled) + # Run once so the runtime compiles some kernels. + compiled(*[inp.clone() for inp in inputs]) + del compiled + gc.collect() + torch.cuda.synchronize() + self.assertTrue( + os.path.exists(self.cache_path), + f"Runtime cache file not written to {self.cache_path}", + ) + self.assertGreater( + os.path.getsize(self.cache_path), + 0, + "Runtime cache file is empty", + ) + + +if __name__ == "__main__": + run_tests() From 142ce1d4ec63fd9e71c64ba39853ac522f0c5d4e Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Thu, 28 May 2026 20:28:57 -0700 Subject: [PATCH 6/9] fix(runtime): add RTX-only entries to kSerializedInfoIndexNames verify_serialization_fmt iterates over the serialized engine info and fetches the human-readable index name from kSerializedInfoIndexNames. On RTX builds SERIALIZATION_LEN is 15 but only 12 names were initialized, leaving the remaining 3 std::array slots zero-initialized to nullptr. fprintf("%s", name) on a null pointer is undefined behavior and segfaults in practice when an engine is deserialized via the def_pickle path. Gate the three RTX-only names on TRT_MAJOR_RTX to mirror the SerializedInfoIndex enum and keep the array fully initialized on both backends. 
--- core/runtime/runtime.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 25d9cd6dd2..3f02f8539c 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -62,6 +62,11 @@ inline constexpr std::array kSerializedInfoIndex "REQUIRES_OUTPUT_ALLOCATOR_IDX", "RESOURCE_ALLOCATION_STRATEGY_IDX", "REQUIRES_NATIVE_MULTIDEVICE_IDX", +#ifdef TRT_MAJOR_RTX + "RUNTIME_CACHE_PATH_IDX", + "DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", + "CUDA_GRAPH_STRATEGY_IDX", +#endif }}; // For adding new serialized info indices, update above and update /dynamo/runtime/_serialized_engine_layout.py From e69ee29581241719e7957949d3d29411fa1fa8ed Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Fri, 29 May 2026 01:45:32 -0700 Subject: [PATCH 7/9] refactor(runtime): address PR review comments on TRTEngine and TRTRuntimeConfig Two review-feedback changes: 1. Revert sink-by-value on pre-existing TRTEngine constructor parameters (serialized_engine, serialized_metadata, mod_name) back to const-ref. A broader sink-by-value sweep across all existing fields belongs in a separate follow-up; this PR only keeps pass-by-value + std::move for the new TRTRuntimeConfig parameter. 2. Gate the lazy-strategy capturability check on whether the engine actually has dynamic-shape inputs, mirroring the Python _is_monolithic_capturable implementation. Static-shape engines remain monolithically capturable under the lazy strategy because lazy only swaps specialized kernels mid-run on dynamic-shape inputs. - New file-local helper engine_has_dynamic_inputs() in TRTEngine.cpp walks the input bindings (including shape tensors) and reports whether any dimension is dynamic. - TRTRuntimeConfig gains a cached bool has_dynamic_inputs (default true so the conservative branch is taken if the flag is never populated); the TRTEngine constructor assigns to it once after binding names are known. 
- is_monolithic_capturable returns true under kLazy iff has_dynamic_inputs is false. --- core/runtime/TRTEngine.cpp | 37 +++++++++++++++++++++++-------- core/runtime/TRTEngine.h | 10 ++++----- core/runtime/TRTRuntimeConfig.cpp | 11 +++++---- core/runtime/TRTRuntimeConfig.h | 2 ++ 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 3a16354376..c191598e43 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -24,6 +24,23 @@ namespace torch_tensorrt { namespace core { namespace runtime { +namespace { +// TensorRT marks unspecified dimensions in dynamic-shape engines with -1. +constexpr int32_t kDynamicDim = -1; + +// Returns true iff any of the listed input bindings (including shape tensors) has a +// dynamic dimension. +[[nodiscard]] bool engine_has_dynamic_inputs( + nvinfer1::ICudaEngine* cuda_engine, + std::vector const& in_binding_names) { + TORCHTRT_CHECK(cuda_engine != nullptr, "engine_has_dynamic_inputs requires a live ICudaEngine"); + return std::any_of(std::begin(in_binding_names), std::cend(in_binding_names), [cuda_engine](std::string const& name) { + auto const dims = cuda_engine->getTensorShape(name.c_str()); + return std::any_of(dims.d, dims.d + dims.nbDims, [](int32_t d) { return d == kDynamicDim; }); + }); +} +} // namespace + std::string slugify(std::string s) { std::replace(s.begin(), s.end(), '.', '_'); return s; @@ -71,26 +88,26 @@ void TRTEngine::record_active_input_tensor_stream_usage(const c10::cuda::CUDAStr } TRTEngine::TRTEngine( - std::string serialized_engine, + const std::string& serialized_engine, const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - std::string serialized_metadata, + const std::string& serialized_metadata, const ResourceAllocationStrategy resource_allocation_strategy, TRTRuntimeConfig 
runtime_cfg) : TRTEngine( "deserialized_trt", - std::move(serialized_engine), + serialized_engine, cuda_device, _in_binding_names, _out_binding_names, target_platform, hardware_compatible, requires_output_allocator, - std::move(serialized_metadata), + serialized_metadata, resource_allocation_strategy, std::move(runtime_cfg)) {} @@ -116,15 +133,15 @@ TRTEngine::TRTEngine(std::vector serialized_info) } TRTEngine::TRTEngine( - std::string mod_name, - std::string serialized_engine, + const std::string& mod_name, + const std::string& serialized_engine, const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - std::string serialized_metadata, + const std::string& serialized_metadata, const ResourceAllocationStrategy resource_allocation_strategy, TRTRuntimeConfig runtime_cfg) { this->runtime_cfg = std::move(runtime_cfg); @@ -138,7 +155,7 @@ TRTEngine::TRTEngine( auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); - this->serialized_metadata = std::move(serialized_metadata); + this->serialized_metadata = serialized_metadata; this->requires_output_allocator = requires_output_allocator; device_info = most_compatible_device.value(); multi_gpu_device_check(); @@ -148,7 +165,7 @@ TRTEngine::TRTEngine( rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger())); - name = slugify(std::move(mod_name)); + name = slugify(mod_name); cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size())); TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine"); @@ -251,6 +268,8 @@ TRTEngine::TRTEngine( num_io = std::make_pair(inputs_size, outputs); } + this->runtime_cfg.has_dynamic_inputs =
engine_has_dynamic_inputs(cuda_engine.get(), in_binding_names); + #ifndef NDEBUG this->enable_profiling(); #endif diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 6347357f7c..c255bfcda3 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -150,14 +150,14 @@ struct TRTEngine : torch::CustomClassHolder { ~TRTEngine(); TRTEngine( - std::string serialized_engine, + const std::string& serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - std::string serialized_metadata = "", + const std::string& serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = TRTEngine::ResourceAllocationStrategy::kStatic, TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); @@ -165,15 +165,15 @@ struct TRTEngine : torch::CustomClassHolder { TRTEngine(std::vector serialized_info); TRTEngine( - std::string mod_name, - std::string serialized_engine, + const std::string& mod_name, + const std::string& serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - std::string serialized_metadata = "", + const std::string& serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = TRTEngine::ResourceAllocationStrategy::kStatic, TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp index c0f6e8c37e..f92935a145 100644 --- a/core/runtime/TRTRuntimeConfig.cpp +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -199,10 +199,13 @@ bool TRTRuntimeConfig::is_monolithic_capturable( TORCHTRT_UNUSED cudaStream_t stream) const 
{ #ifdef TRT_MAJOR_RTX TORCHTRT_ASSERT(exec_ctx != nullptr, "is_monolithic_capturable requires a live IExecutionContext"); - // "lazy" kernel specialization swaps specialized kernels in mid-run, which invalidates - // captured graphs. Other strategies (eager/none) are safe when the context reports the - // stream capturable. - return exec_ctx->isStreamCapturable(stream) && dynamic_shapes_kernel_strategy != DynamicShapesKernelStrategy::kLazy; + if (!exec_ctx->isStreamCapturable(stream)) { + return false; + } + // "lazy" kernel specialization only swaps specialized kernels mid-run when an input + // has a dynamic dimension; for static-shape engines the kernels are fixed at setup and + // the captured graph stays valid. Mirrors the Python `_is_monolithic_capturable` check. + return !(dynamic_shapes_kernel_strategy == DynamicShapesKernelStrategy::kLazy && has_dynamic_inputs); #else return true; #endif diff --git a/core/runtime/TRTRuntimeConfig.h b/core/runtime/TRTRuntimeConfig.h index 489d59fcd0..6e7b8bc6ab 100644 --- a/core/runtime/TRTRuntimeConfig.h +++ b/core/runtime/TRTRuntimeConfig.h @@ -41,6 +41,8 @@ struct TRTRuntimeConfig { // owning engine's lifetime. bool rtx_native_cudagraphs_disabled = false; + bool has_dynamic_inputs = true; + // Live resources. The IRuntimeConfig is lazy-constructed on first `ensure_initialized` // and is unavailable on TensorRT versions older than 10.11 (e.g. Jetpack). #ifdef TRT_HAS_IRUNTIME_CONFIG From bab151c688cfd43e9243b47e0cdb5338fd63941f Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Fri, 29 May 2026 02:36:34 -0700 Subject: [PATCH 8/9] refactor(runtime): unify serialized engine layout across TRT and TRT-RTX Drop the #ifdef TRT_MAJOR_RTX gate on the new SerializedInfoIndex entries so standard TRT and TRT-RTX engines share an identical on-disk layout. A saved program can be inspected and round-tripped across backends without a length mismatch. 
Add HAS_RUNTIME_CFG_IDX as a sentinel flag immediately before the three TRTRuntimeConfig slots. The producer writes "1" iff it authored the next three slots; the consumer treats them as defaults when the flag is "0". SERIALIZATION_LEN is now 16 on both backends (was 15 on RTX, 12 on standard). The Python layout-check op table merges the previous RTX-only checks into _LAYOUT_CPP_CHECKS, and register_jit_hooks always exposes the four *_IDX accessors. --- core/runtime/TRTEngine.cpp | 14 +++---- core/runtime/TRTEngine.h | 6 +-- core/runtime/TRTRuntimeConfig.cpp | 14 +++---- core/runtime/register_jit_hooks.cpp | 3 +- core/runtime/runtime.h | 7 ++-- .../dynamo/runtime/_TorchTensorRTModule.py | 26 ++++++------- .../runtime/_serialized_engine_layout.py | 38 ++++++++----------- 7 files changed, 47 insertions(+), 61 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index c191598e43..0fe9d73730 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -506,14 +506,11 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]), - std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]) -#ifdef TRT_MAJOR_RTX - , + std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]), + std::tuple("has_runtime_cfg", serialized_info[HAS_RUNTIME_CFG_IDX]), std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]), std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), - std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX]) -#endif - ); + std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX])); } std::vector TRTEngine::serialize() { @@
-541,12 +538,15 @@ std::vector TRTEngine::serialize() { serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = this->requires_native_multidevice ? "1" : "0"; // rank/world_size are runtime facts (may differ at load time); not serialized. #ifdef TRT_MAJOR_RTX + serialized_info[HAS_RUNTIME_CFG_IDX] = "1"; +#else + serialized_info[HAS_RUNTIME_CFG_IDX] = "0"; +#endif serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path; serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string( static_cast>(runtime_cfg.dynamic_shapes_kernel_strategy)); serialized_info[CUDA_GRAPH_STRATEGY_IDX] = std::to_string(static_cast>(runtime_cfg.cuda_graph_strategy)); -#endif return serialized_info; } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index c255bfcda3..47917e9c37 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -48,13 +48,11 @@ using FlattenedState = std::tuple< std::tuple, // serialized metadata std::tuple, // Platform std::tuple, // Resource Allocation Strategy - std::tuple // requires_native_multidevice -#ifdef TRT_MAJOR_RTX - , + std::tuple, // requires_native_multidevice + std::tuple, // has_runtime_cfg (gates next three) std::tuple, // Runtime Cache Path (TRT-RTX) std::tuple, // Dynamic Shapes Kernel Strategy (TRT-RTX) std::tuple // CUDA Graph Strategy (TRT-RTX) -#endif >; struct TorchTRTRuntimeStates { diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp index f92935a145..6f64a95cbd 100644 --- a/core/runtime/TRTRuntimeConfig.cpp +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -234,14 +234,14 @@ std::string TRTRuntimeConfig::to_str() const { return os.str(); } -TRTRuntimeConfig make_runtime_config_from_serialized(TORCHTRT_UNUSED const std::vector& info) { +TRTRuntimeConfig make_runtime_config_from_serialized(const std::vector& info) { TRTRuntimeConfig cfg; -#ifdef TRT_MAJOR_RTX - cfg.runtime_cache_path = info[RUNTIME_CACHE_PATH_IDX]; - cfg.dynamic_shapes_kernel_strategy = - 
to_dynamic_shapes_kernel_strategy(std::stoi(info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX])); - cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(std::stoi(info[CUDA_GRAPH_STRATEGY_IDX])); -#endif + if (info[HAS_RUNTIME_CFG_IDX] == "1") { + cfg.runtime_cache_path = info[RUNTIME_CACHE_PATH_IDX]; + cfg.dynamic_shapes_kernel_strategy = + to_dynamic_shapes_kernel_strategy(std::stoi(info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX])); + cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(std::stoi(info[CUDA_GRAPH_STRATEGY_IDX])); + } return cfg; } diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 749f7c7f81..44d1b314ca 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -147,11 +147,10 @@ TORCH_LIBRARY(tensorrt, m) { return false; #endif }); -#ifdef TRT_MAJOR_RTX + m.def("HAS_RUNTIME_CFG_IDX", []() -> int64_t { return HAS_RUNTIME_CFG_IDX; }); m.def("RUNTIME_CACHE_PATH_IDX", []() -> int64_t { return RUNTIME_CACHE_PATH_IDX; }); m.def("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", []() -> int64_t { return DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX; }); m.def("CUDA_GRAPH_STRATEGY_IDX", []() -> int64_t { return CUDA_GRAPH_STRATEGY_IDX; }); -#endif m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 3f02f8539c..2cbe73d6da 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -41,11 +41,11 @@ typedef enum { REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, REQUIRES_NATIVE_MULTIDEVICE_IDX, -#ifdef TRT_MAJOR_RTX + // HAS_RUNTIME_CFG_IDX gates the next three slots. When "0", their values are ignored. 
+ HAS_RUNTIME_CFG_IDX, RUNTIME_CACHE_PATH_IDX, DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, CUDA_GRAPH_STRATEGY_IDX, -#endif SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; @@ -62,11 +62,10 @@ inline constexpr std::array kSerializedInfoIndex "REQUIRES_OUTPUT_ALLOCATOR_IDX", "RESOURCE_ALLOCATION_STRATEGY_IDX", "REQUIRES_NATIVE_MULTIDEVICE_IDX", -#ifdef TRT_MAJOR_RTX + "HAS_RUNTIME_CFG_IDX", "RUNTIME_CACHE_PATH_IDX", "DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", "CUDA_GRAPH_STRATEGY_IDX", -#endif }}; // For adding new serialized info indices, update above and update /dynamo/runtime/_serialized_engine_layout.py diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index c21f4e6496..77253f3deb 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -18,6 +18,7 @@ DEVICE_IDX, DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, ENGINE_IDX, + HAS_RUNTIME_CFG_IDX, HW_COMPATIBLE_IDX, INPUT_BINDING_NAMES_IDX, NAME_IDX, @@ -145,9 +146,6 @@ def __init__( self.execute_engine_op: Any = None self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources - # TensorRT-RTX-only runtime config mirror. The engine-info serialization slots - # only exist on RTX builds (see below), but we validate the strategy names on - # every build so typos are caught regardless of backend. 
self.runtime_cache_path = settings.runtime_cache_path self.dynamic_shapes_kernel_specialization_strategy = ( settings.dynamic_shapes_kernel_specialization_strategy @@ -264,18 +262,16 @@ def _pack_engine_info(self) -> List[str | bytes]: int(self.requires_native_multidevice) ) # rank/world_size are runtime facts; queried from ProcessGroup at execution time - # Strategy names are validated at __init__ time so typos fail fast on every - # build; the index slots themselves only exist on RTX. - if ENABLED_FEATURES.tensorrt_rtx: - engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" - engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( - _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP[ - self.dynamic_shapes_kernel_specialization_strategy - ] - ) - engine_info[CUDA_GRAPH_STRATEGY_IDX] = str( - _CUDA_GRAPH_STRATEGY_MAP[self.cuda_graph_strategy] - ) + engine_info[HAS_RUNTIME_CFG_IDX] = "1" if ENABLED_FEATURES.tensorrt_rtx else "0" + engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" + engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( + _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP[ + self.dynamic_shapes_kernel_specialization_strategy + ] + ) + engine_info[CUDA_GRAPH_STRATEGY_IDX] = str( + _CUDA_GRAPH_STRATEGY_MAP[self.cuda_graph_strategy] + ) return engine_info diff --git a/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py b/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py index ef8938dc8d..c0bc6653b9 100644 --- a/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py +++ b/py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py @@ -37,6 +37,11 @@ class SerializedInfoIndex(IntEnum): REQUIRES_OUTPUT_ALLOCATOR_IDX = 9 RESOURCE_ALLOCATION_STRATEGY_IDX = 10 REQUIRES_NATIVE_MULTIDEVICE_IDX = 11 + # HAS_RUNTIME_CFG_IDX gates the next three slots. When "0", their values are ignored. 
+ HAS_RUNTIME_CFG_IDX = 12 + RUNTIME_CACHE_PATH_IDX = 13 + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = 14 + CUDA_GRAPH_STRATEGY_IDX = 15 # Module-level aliases for backward compatibility and concise access @@ -52,18 +57,13 @@ class SerializedInfoIndex(IntEnum): REQUIRES_OUTPUT_ALLOCATOR_IDX = SerializedInfoIndex.REQUIRES_OUTPUT_ALLOCATOR_IDX RESOURCE_ALLOCATION_STRATEGY_IDX = SerializedInfoIndex.RESOURCE_ALLOCATION_STRATEGY_IDX REQUIRES_NATIVE_MULTIDEVICE_IDX = SerializedInfoIndex.REQUIRES_NATIVE_MULTIDEVICE_IDX - -# TensorRT-RTX-only indices. The C++ side gates these on ``#ifdef TRT_MAJOR_RTX``, -# so they only exist when the loaded runtime is the RTX build. -RUNTIME_CACHE_PATH_IDX = -1 -DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = -1 -CUDA_GRAPH_STRATEGY_IDX = -1 +HAS_RUNTIME_CFG_IDX = SerializedInfoIndex.HAS_RUNTIME_CFG_IDX +RUNTIME_CACHE_PATH_IDX = SerializedInfoIndex.RUNTIME_CACHE_PATH_IDX +DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = ( + SerializedInfoIndex.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX +) +CUDA_GRAPH_STRATEGY_IDX = SerializedInfoIndex.CUDA_GRAPH_STRATEGY_IDX SERIALIZATION_LEN = len(SerializedInfoIndex) -if ENABLED_FEATURES.tensorrt_rtx: - RUNTIME_CACHE_PATH_IDX = len(SerializedInfoIndex) - DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = len(SerializedInfoIndex) + 1 - CUDA_GRAPH_STRATEGY_IDX = len(SerializedInfoIndex) + 2 - SERIALIZATION_LEN = len(SerializedInfoIndex) + 3 SERIALIZED_ENGINE_BINDING_DELIM = "%" SERIALIZED_RT_DEVICE_DELIM = "%" @@ -84,16 +84,13 @@ class SerializedInfoIndex(IntEnum): ("REQUIRES_OUTPUT_ALLOCATOR_IDX", "REQUIRES_OUTPUT_ALLOCATOR_IDX", int), ("RESOURCE_ALLOCATION_STRATEGY_IDX", "RESOURCE_ALLOCATION_STRATEGY_IDX", int), ("REQUIRES_NATIVE_MULTIDEVICE_IDX", "REQUIRES_NATIVE_MULTIDEVICE_IDX", int), - ("SERIALIZATION_LEN", "SERIALIZATION_LEN", int), - ("SERIALIZED_ENGINE_BINDING_DELIM", "SERIALIZED_ENGINE_BINDING_DELIM", str), - ("SERIALIZED_RT_DEVICE_DELIM", "SERIALIZED_RT_DEVICE_DELIM", str), -) - -# TensorRT-RTX-only checks. 
The C++ ops are only registered on RTX builds. -_LAYOUT_CPP_CHECKS_RTX: tuple[_LayoutCheck, ...] = ( + ("HAS_RUNTIME_CFG_IDX", "HAS_RUNTIME_CFG_IDX", int), ("RUNTIME_CACHE_PATH_IDX", "RUNTIME_CACHE_PATH_IDX", int), ("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", "DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", int), ("CUDA_GRAPH_STRATEGY_IDX", "CUDA_GRAPH_STRATEGY_IDX", int), + ("SERIALIZATION_LEN", "SERIALIZATION_LEN", int), + ("SERIALIZED_ENGINE_BINDING_DELIM", "SERIALIZED_ENGINE_BINDING_DELIM", str), + ("SERIALIZED_RT_DEVICE_DELIM", "SERIALIZED_RT_DEVICE_DELIM", str), ) @@ -101,10 +98,7 @@ def _assert_serialized_layout_matches_cpp() -> None: """Fail fast if Python layout literals diverge from ``register_jit_hooks.cpp``.""" if not ENABLED_FEATURES.torch_tensorrt_runtime: return - checks = _LAYOUT_CPP_CHECKS - if ENABLED_FEATURES.tensorrt_rtx: - checks = checks + _LAYOUT_CPP_CHECKS_RTX - for op_name, global_name, normalizer in checks: + for op_name, global_name, normalizer in _LAYOUT_CPP_CHECKS: expected = globals()[global_name] try: op = getattr(torch.ops.tensorrt, op_name) From ae68670509f12079d008aac9a868af723f5db61b Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Mon, 1 Jun 2026 20:50:18 -0700 Subject: [PATCH 9/9] test(runtime): parameterize TRT-RTX runtime tests over python/cpp Drop the standalone test_001_cpp_runtime_rtx_features.py and extend the existing python-runtime test files (test_000_runtime_cache.py, test_001_dynamic_shapes_kernel_strategy.py, test_001_cuda_graph_strategy.py) to also exercise use_python_runtime=False via parameterized.expand over [("python", True), ("cpp", False)]. Mirrors the pattern from the original #4202; updated for the post-#4222 unified runtime where there is no process-wide backend switch. Setup / whitebox introspection tests stay python-only because they read engine.runtime_config on the Python TRTEngine, which the C++ torch.classes.tensorrt.Engine does not expose.
End-to-end inference, persistence, and typo-rejection tests run on both runtimes. Add a TestSerializationIndices class that verifies the new HAS_RUNTIME_CFG flag + three TRTRuntimeConfig slots are registered identically on both backends. Add a LOG_INFO("[torch-TensorRT C++ runtime] TRTEngine constructed from serialized info") in the C++-only TRTEngine(std::vector<std::string>) constructor. Running compile with use_python_runtime=False under torch_tensorrt.logging.set_level(logging.INFO) produces exactly one such log line; use_python_runtime=True produces zero. --- core/runtime/TRTEngine.cpp | 5 + .../dynamo/runtime/test_000_runtime_cache.py | 178 ++++++++++++++---- .../test_001_cpp_runtime_rtx_features.py | 174 ----------------- .../runtime/test_001_cuda_graph_strategy.py | 109 +++++++++++ ...test_001_dynamic_shapes_kernel_strategy.py | 104 ++++++++++ 5 files changed, 358 insertions(+), 212 deletions(-) delete mode 100644 tests/py/dynamo/runtime/test_001_cpp_runtime_rtx_features.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 0fe9d73730..873c758854 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -126,6 +126,11 @@ TRTEngine::TRTEngine(std::vector serialized_info) ? ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic), make_runtime_config_from_serialized(serialized_info)) { + // Single visible marker that this engine was instantiated through the C++ runtime + // entry point (i.e. torch.classes.tensorrt.Engine), distinguishing it from the Python + // TRTEngine path. Tests look for this string in captured stderr to verify the + // expected backend was exercised.
+ LOG_INFO("[torch-TensorRT C++ runtime] TRTEngine constructed from serialized info"); this->requires_native_multidevice = std::stoi(serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]); if (this->requires_native_multidevice) { LOG_INFO("Loaded distributed TRT engine (contains NCCL collectives); NCCL comm will be bound on first execution"); diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index 05637a6146..2e9855b9a4 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -1,5 +1,4 @@ import gc -import logging import os import shutil import tempfile @@ -7,10 +6,11 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH, TIMING_CACHE_PATH -from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity class SimpleModel(torch.nn.Module): @@ -18,39 +18,58 @@ def forward(self, x): return torch.relu(x) + 1.0 -class TwoLayerModel(torch.nn.Module): +class ConvModel(torch.nn.Module): def __init__(self): super().__init__() - self.linear = torch.nn.Linear(8, 8) + self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) def forward(self, x): - return torch.relu(self.linear(x)) + return torch.relu(self.conv(x)) -def _compile_simple(runtime_cache_path=None): - """Helper: compile SimpleModel with Python runtime, return (compiled_module, inputs).""" - model = SimpleModel().eval().cuda() - inputs = [torch.randn(2, 3).cuda()] +def _fresh_conv_model_and_inputs(seed=0): + """Deterministic ConvModel + input pair for end-to-end cache tests on either runtime.""" + torch.manual_seed(seed) + return ConvModel().eval().cuda(), [torch.randn(2, 3, 16, 16).cuda()] + + +def 
_compile(model, inputs, *, use_python_runtime, runtime_cache_path=None): + """Compile ``model`` through either runtime. Returns the compiled module.""" kwargs = { "ir": "dynamo", "inputs": inputs, - "use_python_runtime": True, + "use_python_runtime": use_python_runtime, "min_block_size": 1, } if runtime_cache_path is not None: kwargs["runtime_cache_path"] = runtime_cache_path compiled = torchtrt.compile(model, **kwargs) torch._dynamo.reset() - return compiled, inputs + return compiled + + +def _compile_simple(runtime_cache_path=None): + """Compile SimpleModel on the Python runtime (used by introspection setup tests).""" + model = SimpleModel().eval().cuda() + inputs = [torch.randn(2, 3).cuda()] + return ( + _compile( + model, + inputs, + use_python_runtime=True, + runtime_cache_path=runtime_cache_path, + ), + inputs, + ) def _find_python_trt_engine(compiled): - """Walk the compiled graph module and return the Python ``TRTEngine`` instance. + """Return the Python ``TRTEngine`` instance from a compiled module, if any. The C++ and Python runtimes are now both driven through ``TorchTensorRTModule`` - (``use_python_runtime`` selects which backend is constructed). For tests that - target the Python runtime specifically we look for the wrapping module and - return its ``.engine`` attribute when it's a Python ``TRTEngine``. + (``use_python_runtime`` selects which backend is constructed). Tests that target + Python-runtime introspection use this helper; C++-runtime tests rely on + externally observable behavior (cache file on disk, inference correctness). """ from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine @@ -61,6 +80,16 @@ def _find_python_trt_engine(compiled): return None +# Parameterize end-to-end cache persistence tests over both runtime paths. The C++ +# variant is skipped inside the test body when the C++ runtime is not available. 
+_RUNTIMES = [("python", True), ("cpp", False)] + + +def _skip_if_cpp_unavailable(testcase, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + testcase.skipTest("C++ runtime is not available") + + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, "Runtime cache is only available with TensorRT-RTX", @@ -108,7 +137,7 @@ def test_runtime_cache_path_custom(self): "Runtime cache is only available with TensorRT-RTX", ) class TestRuntimeCachePersistence(TestCase): - """Tests that runtime cache is correctly saved to and loaded from disk.""" + """Load-on-setup / save-on-destructor contract, exercised on both runtimes.""" def setUp(self): self.cache_dir = tempfile.mkdtemp() @@ -117,9 +146,16 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.cache_dir, ignore_errors=True) - def test_cache_saved_on_del(self): - compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path) - # Run inference to populate the cache + @parameterized.expand(_RUNTIMES) + def test_cache_saved_on_del(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) + model, inputs = _fresh_conv_model_and_inputs() + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) _ = compiled(*[inp.clone() for inp in inputs]) self.assertFalse( os.path.isfile(self.cache_path), @@ -132,8 +168,16 @@ def test_cache_saved_on_del(self): "Cache file should be created after module cleanup", ) - def test_cache_file_nonempty(self): - compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path) + @parameterized.expand(_RUNTIMES) + def test_cache_file_nonempty(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) + model, inputs = _fresh_conv_model_and_inputs() + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) _ = compiled(*[inp.clone() for inp in 
inputs]) del compiled gc.collect() @@ -143,30 +187,54 @@ def test_cache_file_nonempty(self): "Cache file should have nonzero size", ) - def test_cache_roundtrip(self): - """Compile, infer, save. Then compile again with same cache path and verify correctness.""" - model = SimpleModel().eval().cuda() - inputs = [torch.randn(2, 3).cuda()] - ref_output = model(*inputs) - - # First compilation — populates and saves cache - compiled1, _ = _compile_simple(runtime_cache_path=self.cache_path) - _ = compiled1(*[inp.clone() for inp in inputs]) + @parameterized.expand(_RUNTIMES) + def test_cache_roundtrip(self, _name, use_python_runtime): + """Populate + save, then recompile and confirm correctness against eager output.""" + _skip_if_cpp_unavailable(self, use_python_runtime) + model, inputs = _fresh_conv_model_and_inputs() + with torch.no_grad(): + ref_output = model(*inputs) + + compiled1 = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) + out1 = compiled1(*[inp.clone() for inp in inputs]) + self.assertGreater( + cosine_similarity(ref_output, out1), + COSINE_THRESHOLD, + "First compiled output should match eager", + ) del compiled1 gc.collect() self.assertTrue(os.path.isfile(self.cache_path)) - # Second compilation — should load cached data - compiled2, _ = _compile_simple(runtime_cache_path=self.cache_path) - output = compiled2(*[inp.clone() for inp in inputs]) - max_diff = float(torch.max(torch.abs(ref_output - output))) - self.assertAlmostEqual( - max_diff, 0, places=3, msg="Output mismatch after cache roundtrip" + compiled2 = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) + out2 = compiled2(*[inp.clone() for inp in inputs]) + self.assertGreater( + cosine_similarity(ref_output, out2), + COSINE_THRESHOLD, + "Second compiled output (warm cache) should still match eager", ) - def test_save_creates_directory(self): + @parameterized.expand(_RUNTIMES) 
+ def test_save_creates_directory(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) nested_path = os.path.join(self.cache_dir, "a", "b", "c", "runtime_cache.bin") - compiled, inputs = _compile_simple(runtime_cache_path=nested_path) + model, inputs = _fresh_conv_model_and_inputs() + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=nested_path, + ) _ = compiled(*[inp.clone() for inp in inputs]) del compiled gc.collect() @@ -292,5 +360,39 @@ def test_timing_cache_still_created(self): ) +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +class TestSerializationIndices(TestCase): + """The HAS_RUNTIME_CFG flag + TRTRuntimeConfig slots are present on both backends.""" + + def test_indices_match_python_layout(self): + from torch_tensorrt.dynamo.runtime._serialized_engine_layout import ( + CUDA_GRAPH_STRATEGY_IDX, + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, + HAS_RUNTIME_CFG_IDX, + RUNTIME_CACHE_PATH_IDX, + SERIALIZATION_LEN, + ) + + self.assertEqual(int(torch.ops.tensorrt.SERIALIZATION_LEN()), SERIALIZATION_LEN) + self.assertEqual( + int(torch.ops.tensorrt.HAS_RUNTIME_CFG_IDX()), int(HAS_RUNTIME_CFG_IDX) + ) + self.assertEqual( + int(torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX()), + int(RUNTIME_CACHE_PATH_IDX), + ) + self.assertEqual( + int(torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX()), + int(DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX), + ) + self.assertEqual( + int(torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX()), + int(CUDA_GRAPH_STRATEGY_IDX), + ) + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/runtime/test_001_cpp_runtime_rtx_features.py b/tests/py/dynamo/runtime/test_001_cpp_runtime_rtx_features.py deleted file mode 100644 index f92d8834be..0000000000 --- a/tests/py/dynamo/runtime/test_001_cpp_runtime_rtx_features.py +++ /dev/null @@ -1,174 +0,0 @@ -"""C++ runtime smoke tests for the three TensorRT-RTX runtime 
features. - -These tests verify the C++ runtime path (``use_python_runtime=False``) wires up -the runtime cache, dynamic-shapes kernel specialization strategy, and native -CUDA graph strategy via the serialized engine info indices. The Python runtime -equivalents live in ``test_000_runtime_cache.py``, -``test_001_dynamic_shapes_kernel_strategy.py``, and -``test_001_cuda_graph_strategy.py`` and assert via Python attributes on -:class:`TRTEngine`; the C++ ``torch.classes.tensorrt.Engine`` does not expose -those attributes to Python, so this file asserts on externally observable -behavior (compilation succeeds, inference returns correct outputs, cache files -appear on disk). -""" - -from __future__ import annotations - -import os -import shutil -import tempfile -import unittest - -import torch -import torch_tensorrt as torchtrt -from torch.testing._internal.common_utils import TestCase, run_tests -from torch_tensorrt._features import ENABLED_FEATURES -from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity - - -class SimpleModel(torch.nn.Module): - def forward(self, x): - return torch.relu(x) + 1.0 - - -def _compile_cpp(**extra_kwargs): - """Compile :class:`SimpleModel` against the C++ runtime.""" - model = SimpleModel().eval().cuda() - inputs = [torch.randn(2, 3).cuda()] - kwargs = { - "ir": "dynamo", - "inputs": inputs, - "use_python_runtime": False, - "min_block_size": 1, - } - kwargs.update(extra_kwargs) - compiled = torchtrt.compile(model, **kwargs) - torch._dynamo.reset() - return compiled, inputs, model - - -def _assert_cpp_runtime_used(testcase: TestCase, compiled) -> None: - """Walk the compiled module and assert at least one C++ engine is present.""" - from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule - from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine - - found_cpp = False - for _, mod in compiled.named_modules(): - if isinstance(mod, TorchTensorRTModule): - testcase.assertFalse( - 
isinstance(mod.engine, TRTEngine), - "C++ runtime expected but found Python TRTEngine", - ) - found_cpp = True - testcase.assertTrue(found_cpp, "No TorchTensorRTModule found in compiled graph") - - -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "RTX-only features require TensorRT-RTX", -) -class TestCppRuntimeStrategyValidation(TestCase): - """Strategy-name typos must be rejected before engine construction.""" - - def test_invalid_dynamic_shapes_strategy_rejected(self): - with self.assertRaises(ValueError): - _compile_cpp(dynamic_shapes_kernel_specialization_strategy="invalid") - - def test_invalid_cuda_graph_strategy_rejected(self): - with self.assertRaises(ValueError): - _compile_cpp(cuda_graph_strategy="invalid_strategy") - - -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "RTX-only features require TensorRT-RTX", -) -class TestCppRuntimeSmoke(TestCase): - """End-to-end compile + infer + correctness on the C++ runtime.""" - - def _run_and_check(self, compiled, inputs, model): - ref = model(*inputs) - out = compiled(*[inp.clone() for inp in inputs]) - sim = cosine_similarity(ref, out) - self.assertGreaterEqual( - sim, - COSINE_THRESHOLD, - f"C++ runtime output diverged from reference (cosine={sim})", - ) - - def test_default_settings(self): - compiled, inputs, model = _compile_cpp() - _assert_cpp_runtime_used(self, compiled) - self._run_and_check(compiled, inputs, model) - - def test_eager_kernel_strategy(self): - compiled, inputs, model = _compile_cpp( - dynamic_shapes_kernel_specialization_strategy="eager" - ) - _assert_cpp_runtime_used(self, compiled) - self._run_and_check(compiled, inputs, model) - - def test_none_kernel_strategy(self): - compiled, inputs, model = _compile_cpp( - dynamic_shapes_kernel_specialization_strategy="none" - ) 
- _assert_cpp_runtime_used(self, compiled) - self._run_and_check(compiled, inputs, model) - - def test_whole_graph_capture(self): - compiled, inputs, model = _compile_cpp( - cuda_graph_strategy="whole_graph_capture" - ) - _assert_cpp_runtime_used(self, compiled) - self._run_and_check(compiled, inputs, model) - - -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "RTX-only features require TensorRT-RTX", -) -class TestCppRuntimeCachePersistence(TestCase): - """Verify the C++ runtime writes a cache file on engine destruction.""" - - def setUp(self): - self.cache_dir = tempfile.mkdtemp() - self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") - - def tearDown(self): - shutil.rmtree(self.cache_dir, ignore_errors=True) - - def test_cache_file_written_on_destruction(self): - import gc - - compiled, inputs, model = _compile_cpp(runtime_cache_path=self.cache_path) - _assert_cpp_runtime_used(self, compiled) - # Run once so the runtime compiles some kernels. 
- compiled(*[inp.clone() for inp in inputs]) - del compiled - gc.collect() - torch.cuda.synchronize() - self.assertTrue( - os.path.exists(self.cache_path), - f"Runtime cache file not written to {self.cache_path}", - ) - self.assertGreater( - os.path.getsize(self.cache_path), - 0, - "Runtime cache file is empty", - ) - - -if __name__ == "__main__": - run_tests() diff --git a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py index 8b534be362..29d38d60b1 100644 --- a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py +++ b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py @@ -2,10 +2,44 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._settings import CompilationSettings +_RUNTIMES = [("python", True), ("cpp", False)] + + +def _skip_if_cpp_unavailable(testcase, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + testcase.skipTest("C++ runtime is not available") + + +class CudaGraphConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) + + def forward(self, x): + return torch.relu(self.conv(x)) + + +def _compile_conv(strategy, *, use_python_runtime): + """Compile CudaGraphConvModel through the selected runtime with the given strategy.""" + model = CudaGraphConvModel().eval().cuda() + inputs = [torch.randn(2, 3, 16, 16).cuda()] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=use_python_runtime, + min_block_size=1, + cuda_graph_strategy=strategy, + ) + torch._dynamo.reset() + return compiled, inputs + class SimpleModel(torch.nn.Module): def forward(self, x): @@ -353,5 +387,80 @@ def 
test_setting_ignored_on_non_rtx(self): self.assertEqual(output.shape, (2, 3)) +_STRATEGY_RUNTIME_MATRIX = [ + (strategy, runtime_name, use_python_runtime) + for strategy in ("disabled", "whole_graph_capture") + for (runtime_name, use_python_runtime) in _RUNTIMES +] + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy is a TensorRT-RTX feature", +) +class TestCudaGraphStrategyInference(TestCase): + """End-to-end: compile + infer with each strategy on both runtime paths.""" + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + @parameterized.expand(_STRATEGY_RUNTIME_MATRIX) + def test_strategy_inference(self, strategy, _runtime_name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) + compiled, inputs = _compile_conv( + strategy, use_python_runtime=use_python_runtime + ) + y = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + @parameterized.expand(_RUNTIMES) + def test_whole_graph_capture_with_subgraph_cudagraphs( + self, _name, use_python_runtime + ): + """Subgraph cudagraph mode + RTX strategy: RTX-native should take over without errors.""" + _skip_if_cpp_unavailable(self, use_python_runtime) + compiled, inputs = _compile_conv( + "whole_graph_capture", use_python_runtime=use_python_runtime + ) + torchtrt.runtime.set_cudagraphs_mode(True) + y = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + @parameterized.expand(_RUNTIMES) + def test_repeated_inference(self, _name, use_python_runtime): + """Repeated inference exercises the RTX-native capture/replay path.""" + _skip_if_cpp_unavailable(self, use_python_runtime) + compiled, inputs = _compile_conv( + "whole_graph_capture", use_python_runtime=use_python_runtime + ) + ref = compiled(*[inp.clone() for inp in inputs]) + for _ in range(4): + out = 
compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(out.shape, ref.shape) + self.assertTrue(torch.isfinite(out).all().item()) + + +class TestCudaGraphStrategyInvalidValue(TestCase): + """Invalid strategy names are rejected at TorchTensorRTModule.__init__ on any backend.""" + + @parameterized.expand(_RUNTIMES) + def test_invalid_strategy_raises(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) + model = CudaGraphConvModel().eval().cuda() + inputs = [torch.randn(2, 3, 16, 16).cuda()] + with self.assertRaises((ValueError, RuntimeError)): + torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=use_python_runtime, + min_block_size=1, + cuda_graph_strategy="not_a_real_strategy", + ) + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py index 11998f9b7a..408fb159fc 100644 --- a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py +++ b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py @@ -2,16 +2,59 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._settings import CompilationSettings +_STRATEGIES = [("lazy",), ("eager",), ("none",)] + class SimpleModel(torch.nn.Module): def forward(self, x): return torch.relu(x) + 1.0 +class DynamicConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1) + self.conv2 = torch.nn.Conv2d(16, 8, 3, padding=1) + + def forward(self, x): + return torch.relu(self.conv2(torch.relu(self.conv1(x)))) + + +_RUNTIMES = [("python", True), ("cpp", False)] + + +def _skip_if_cpp_unavailable(testcase, use_python_runtime): + if 
not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + testcase.skipTest("C++ runtime is not available") + + +def _compile_dynamic_conv(strategy, *, use_python_runtime): + """Compile DynamicConvModel through the selected runtime with the given strategy.""" + model = DynamicConvModel().eval().cuda() + inp = torchtrt.Input( + min_shape=(1, 3, 16, 16), + opt_shape=(2, 3, 16, 16), + max_shape=(4, 3, 16, 16), + dtype=torch.float32, + ) + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[inp], + enabled_precisions={torch.float32}, + use_python_runtime=use_python_runtime, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy=strategy, + ) + torch._dynamo.reset() + return compiled + + def _compile_simple(**extra_kwargs): """Helper: compile SimpleModel with dynamic shapes and Python runtime.""" model = SimpleModel().eval().cuda() @@ -147,5 +190,66 @@ def test_setting_ignored_on_non_rtx(self): self.assertEqual(output.shape, (2, 3)) +_STRATEGY_RUNTIME_MATRIX = [ + (strategy, runtime_name, use_python_runtime) + for (strategy,) in _STRATEGIES + for (runtime_name, use_python_runtime) in _RUNTIMES +] + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Dynamic shapes kernel strategy is a TensorRT-RTX feature", +) +class TestDynamicShapesKernelStrategyInference(TestCase): + """End-to-end: compile + infer with each strategy on both runtime paths.""" + + @parameterized.expand(_STRATEGY_RUNTIME_MATRIX) + def test_strategy_inference(self, strategy, _runtime_name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) + compiled = _compile_dynamic_conv( + strategy, use_python_runtime=use_python_runtime + ) + x = torch.randn(2, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + @parameterized.expand(_RUNTIMES) + def test_dynamic_shape_with_eager(self, _name, use_python_runtime): + """Exercise shape changes under 
eager kernel specialization.""" + _skip_if_cpp_unavailable(self, use_python_runtime) + compiled = _compile_dynamic_conv("eager", use_python_runtime=use_python_runtime) + for batch in (1, 2, 3, 4): + x = torch.randn(batch, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (batch, 8, 16, 16)) + + +class TestDynamicShapesKernelStrategyInvalidValue(TestCase): + """Invalid strategy names are rejected at TorchTensorRTModule.__init__ on any backend.""" + + @parameterized.expand(_RUNTIMES) + def test_invalid_strategy_raises(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) + model = DynamicConvModel().eval().cuda() + inp = torchtrt.Input( + min_shape=(1, 3, 16, 16), + opt_shape=(2, 3, 16, 16), + max_shape=(4, 3, 16, 16), + dtype=torch.float32, + ) + with self.assertRaises((ValueError, RuntimeError)): + torchtrt.compile( + model, + ir="dynamo", + inputs=[inp], + enabled_precisions={torch.float32}, + use_python_runtime=use_python_runtime, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy="not_a_real_strategy", + ) + + if __name__ == "__main__": run_tests()