From c3a5f44cc5d918b05cb8dc6862990083db8173c4 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Fri, 29 May 2026 19:43:59 -0700 Subject: [PATCH] fix: run the ExecuTorch TensorRT delegate on a caller-selected CUDA stream The delegate created and owned a private CUDA stream in init() and ran every enqueueV3() on it, so an application could not place inference on a specific CUDA stream or context (for example a CUDA green context for SM partitioning). Let the caller select the stream instead, bringing the libtorch-free ExecuTorch runtime the same caller-stream capability the libtorch TensorRT runtime has (#4232): - Add a scoped CudaStreamGuard (mirroring c10::cuda::CUDAStreamGuard) to select, per calling thread, the CUDA stream the delegate runs TensorRT on. With no guard active the delegate runs on cudaStreamPerThread. - execute() runs enqueueV3() and the staging copies on the selected stream; init() no longer creates a stream and the delegate owns none. - To confine inference to a CUDA green context's SM partition the caller scopes a guard with a stream created on that green context (cuGreenCtxStreamCreate); the partition confinement travels with the stream, so the green context need not be made current. cudaStreamPerThread is invalid while a green context is current (cudaErrorInvalidResourceHandle), so a green-context caller must scope a guard. - cudaSetDevice() is applied only when the engine's device differs from the current device and is restored on exit, so it no longer clobbers a context the caller established. - execute() leaves device-resident outputs enqueued (no end sync) only while a guard is active; the default path and host-staged outputs still synchronize before returning, preserving existing behavior. The caller synchronizes the selected stream when it reads device-resident results. - Make the no-sync path safe to reuse: the handle records a CUDA completion event after the enqueue, and the next execute() (and the destructor) waits on it before reconfiguring or freeing the shared IExecutionContext. A handle can thus be run repeatedly on a caller stream without the caller synchronizing between calls, and teardown never frees a context with an enqueue still in flight. No dependency on the libtorch Torch-TensorRT runtime or libtorch is added. --- .../executorch/TensorRTBackend.h | 37 +++- .../executorch/TensorRTBackend.cpp | 160 ++++++++++++++---- 2 files changed, 165 insertions(+), 32 deletions(-) diff --git a/cpp/include/torch_tensorrt/executorch/TensorRTBackend.h b/cpp/include/torch_tensorrt/executorch/TensorRTBackend.h index e67d698d9c..31383f9e17 100644 --- a/cpp/include/torch_tensorrt/executorch/TensorRTBackend.h +++ b/cpp/include/torch_tensorrt/executorch/TensorRTBackend.h @@ -50,7 +50,6 @@ struct EngineHandle { TRTUniquePtr runtime; TRTUniquePtr engine; TRTUniquePtr exec_ctx; - cudaStream_t stream = nullptr; std::vector input_binding_names; std::vector output_binding_names; std::vector input_profile_bounds; @@ -63,6 +62,13 @@ struct EngineHandle { int device_id = 0; bool unified_memory = false; std::mutex mu; + // Makes the skip-sync fast path safe to reuse: TensorRT forbids reconfiguring or + // destroying an execution context while one of its enqueues is in flight, so when + // execute() returns without an end sync it records this event; the next execute() + // and the destructor wait on it before touching exec_ctx. One event/flag pair + // suffices because a handle runs on a single thread at a time. + cudaEvent_t inflight_event = nullptr; + bool inflight_pending = false; ~EngineHandle(); }; @@ -84,5 +90,34 @@ class TensorRTBackend final : public ::executorch::runtime::BackendInterface { void destroy(::executorch::runtime::DelegateHandle* handle) const override; }; +// Selects, for the calling thread, the CUDA stream the delegate runs TensorRT on; +// scope it around execution. +// +// Confines inference to a CUDA green context's SM partition when the caller +// passes a cuGreenCtxStreamCreate stream: confinement rides the stream (the green +// context need not be current), and cudaStreamPerThread — the no-guard default — +// is rejected while a green context is current. While active, device-resident +// outputs are left enqueued on the stream (no end sync) to compose with later GPU +// work. +// +// Contract: the stream is on the engine's device and outlives the guard; a handle +// is executed by one thread at a time. On the no-end-sync path (guard active, all +// I/O device-resident) execute() returns with the TensorRT enqueue still in flight +// on the stream; the delegate itself orders the next execute() on, and the +// destruction of, that handle after the work completes (via an internal completion +// event), so the caller need only synchronize the stream before reading +// device-resident outputs. +class CudaStreamGuard { + public: + explicit CudaStreamGuard(cudaStream_t stream); + ~CudaStreamGuard(); + CudaStreamGuard(const CudaStreamGuard&) = delete; + CudaStreamGuard& operator=(const CudaStreamGuard&) = delete; + + private: + cudaStream_t prev_stream_; + bool prev_set_; +}; + } // namespace executorch_backend } // namespace torch_tensorrt diff --git a/cpp/src/torch_tensorrt/executorch/TensorRTBackend.cpp b/cpp/src/torch_tensorrt/executorch/TensorRTBackend.cpp index c19bb3754f..b2e3b08232 100644 --- a/cpp/src/torch_tensorrt/executorch/TensorRTBackend.cpp +++ b/cpp/src/torch_tensorrt/executorch/TensorRTBackend.cpp @@ -48,6 +48,21 @@ using ::executorch::runtime::Span; } \ } while (false) +namespace { +thread_local cudaStream_t g_user_stream = nullptr; +thread_local bool g_user_stream_set = false; +} // namespace + +CudaStreamGuard::CudaStreamGuard(cudaStream_t stream) : prev_stream_(g_user_stream), prev_set_(g_user_stream_set) { + g_user_stream = stream; + g_user_stream_set = true; +} + +CudaStreamGuard::~CudaStreamGuard() { + g_user_stream = prev_stream_; + g_user_stream_set = prev_set_; +} + void TRTLogger::log(Severity severity, const char* msg) noexcept { if (severity <= Severity::kERROR) { ET_LOG(Error, "TensorRT: %s", msg); @@ -58,8 +73,23 @@ void TRTLogger::log(Severity severity, const char* msg) noexcept { EngineHandle::~EngineHandle() { cudaSetDevice(device_id); - if (stream != nullptr) { - cudaStreamSynchronize(stream); + // A fast-path execute() may have returned with its enqueue still in flight on the + // caller's stream, still using exec_ctx and the cached staging buffers. Wait on + // the recorded completion event before destroying the context or freeing the + // buffers. We wait on the event, not the stream, so this stays valid even if the + // caller already destroyed the stream. Non-skip executes synchronized inline, so + // inflight_pending is false there. Fall back to a device sync if no event exists. + if (inflight_event != nullptr) { + if (inflight_pending) { + cudaError_t err = cudaEventSynchronize(inflight_event); + if (err != cudaSuccess) { + ET_LOG(Error, "EngineHandle::~EngineHandle: cudaEventSynchronize failed: %s", cudaGetErrorString(err)); + cudaGetLastError(); // clear sticky error; tear down regardless + } + inflight_pending = false; + } + } else { + cudaDeviceSynchronize(); } for (void* p : cached_input_ptrs) { if (p != nullptr) { @@ -74,9 +104,9 @@ EngineHandle::~EngineHandle() { exec_ctx.reset(); engine.reset(); runtime.reset(); - if (stream != nullptr) { - cudaStreamDestroy(stream); - stream = nullptr; + if (inflight_event != nullptr) { + cudaEventDestroy(inflight_event); + inflight_event = nullptr; } } @@ -226,9 +256,12 @@ Result TensorRTBackend::init( return Error::InvalidProgram; } - cuda_err = cudaStreamCreate(&handle->stream); + // Created while device_id is current so the event belongs to the engine's device. + // It orders a later execute()/teardown after a skip-sync enqueue (see execute() + // and ~EngineHandle). Blocking-sync so the host yields instead of busy-spinning. + cuda_err = cudaEventCreateWithFlags(&handle->inflight_event, cudaEventDisableTiming | cudaEventBlockingSync); if (cuda_err != cudaSuccess) { - ET_LOG(Error, "TensorRTBackend::init: cudaStreamCreate failed: %s", cudaGetErrorString(cuda_err)); + ET_LOG(Error, "TensorRTBackend::init: cudaEventCreateWithFlags failed: %s", cudaGetErrorString(cuda_err)); return Error::InvalidProgram; } @@ -298,22 +331,57 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle* return Error::InvalidArgument; } - cudaError_t cuda_err = cudaSetDevice(engine->device_id); + int entry_device = -1; + cudaError_t cuda_err = cudaGetDevice(&entry_device); if (cuda_err != cudaSuccess) { - ET_LOG( - Error, - "TensorRTBackend::execute: cudaSetDevice(%d) failed: %s", - engine->device_id, - cudaGetErrorString(cuda_err)); + ET_LOG(Error, "TensorRTBackend::execute: cudaGetDevice failed: %s", cudaGetErrorString(cuda_err)); return Error::InvalidProgram; } + // Put the engine on its own device for multi-GPU correctness, restoring the + // caller's device on exit; green-context confinement rides the selected stream, + // independent of the current device/context. + const bool switch_device = (entry_device != engine->device_id); + if (switch_device) { + cuda_err = cudaSetDevice(engine->device_id); + if (cuda_err != cudaSuccess) { + ET_LOG( + Error, + "TensorRTBackend::execute: cudaSetDevice(%d) failed: %s", + engine->device_id, + cudaGetErrorString(cuda_err)); + return Error::InvalidProgram; + } + } + struct DeviceRestore { + int device; + bool active; + ~DeviceRestore() { + if (active) { + cudaSetDevice(device); + } + } + } device_restore{entry_device, switch_device}; std::unique_lock lock(engine->mu); nvinfer1::IExecutionContext* ctx = engine->exec_ctx.get(); - cudaStream_t stream = engine->stream; TORCHTRT_ET_CHECK_NOT_NULL(ctx, Error::InvalidState, "TensorRTBackend::execute: backend is not initialized"); - TORCHTRT_ET_CHECK_NOT_NULL(stream, Error::InvalidState, "TensorRTBackend::execute: backend is not initialized"); + + // A prior fast-path execute() may have returned with its enqueue still in flight + // on the shared exec_ctx. Wait for it before reconfiguring the context below: + // TensorRT forbids mutating a context while one of its enqueues is in flight, and + // setInputShape/setTensorAddress run on the host, so this must be a host-side wait. + if (engine->inflight_pending) { + cuda_err = cudaEventSynchronize(engine->inflight_event); + engine->inflight_pending = false; + if (cuda_err != cudaSuccess) { + ET_LOG(Error, "TensorRTBackend::execute: cudaEventSynchronize failed: %s", cudaGetErrorString(cuda_err)); + return Error::InvalidProgram; + } + } + cudaStream_t stream = g_user_stream_set ? g_user_stream : cudaStreamPerThread; + bool output_staged_to_host = false; + bool input_staged_from_host = false; if (engine->cached_input_ptrs.empty()) { engine->cached_input_ptrs.resize(num_inputs, nullptr); @@ -392,6 +460,7 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle* engine->cached_input_sizes[i] = needed; } bind_ptr = engine->cached_input_ptrs[i]; + input_staged_from_host = true; cuda_err = cudaMemcpyAsync(bind_ptr, et_in.const_data_ptr(), needed, cudaMemcpyHostToDevice, stream); if (cuda_err != cudaSuccess) { ET_LOG( @@ -486,6 +555,7 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle* engine->cached_output_sizes[o] = needed; } bind_ptr = engine->cached_output_ptrs[o]; + output_staged_to_host = true; outputs_needing_copy.push_back({o, bind_ptr}); } @@ -499,28 +569,56 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle* // 4. Enqueue inference on the current CUDA stream // ------------------------------------------------------------------ if (!ctx->enqueueV3(stream)) { - ET_LOG(Error, "TensorRTBackend::execute: enqueueV3 failed"); + ET_LOG( + Error, + "TensorRTBackend::execute: enqueueV3 failed. If a CUDA green context is " + "current, scope a CudaStreamGuard with a green-context stream: " + "cudaStreamPerThread is invalid while a green context is current."); return Error::InvalidState; } - for (auto& output : outputs_needing_copy) { - exec_aten::Tensor et_out = args[num_inputs + output.first]->toTensor(); - cuda_err = - cudaMemcpyAsync(et_out.mutable_data_ptr(), output.second, et_out.nbytes(), cudaMemcpyDeviceToHost, stream); + // The engine work is now in flight on `stream`. Decide whether to wait for it: + // must_sync = an output is staged to host (the caller reads the D2H result on + // return), an input was staged from host (its async H2D read the caller's host + // buffer, which the caller may reuse once we return), or no caller stream is + // active (preserve the historical "results ready on return" behavior). + // Otherwise (caller stream + all I/O device-resident) leave the work enqueued so + // it composes with the caller's later GPU work, and record inflight_event so the + // next execute() and the destructor wait before reusing/freeing exec_ctx. The D2H + // copies live in the must_sync branch: an output staged to host always sets + // output_staged_to_host, so outputs_needing_copy is empty on the skip path. + const bool must_sync = output_staged_to_host || input_staged_from_host || !g_user_stream_set; + if (must_sync) { + for (auto& output : outputs_needing_copy) { + exec_aten::Tensor et_out = args[num_inputs + output.first]->toTensor(); + cuda_err = + cudaMemcpyAsync(et_out.mutable_data_ptr(), output.second, et_out.nbytes(), cudaMemcpyDeviceToHost, stream); + if (cuda_err != cudaSuccess) { + ET_LOG( + Error, + "TensorRTBackend::execute: D2H copy failed for output %zu: %s", + output.first, + cudaGetErrorString(cuda_err)); + return Error::InvalidProgram; + } + } + cuda_err = cudaStreamSynchronize(stream); + engine->inflight_pending = false; if (cuda_err != cudaSuccess) { - ET_LOG( - Error, - "TensorRTBackend::execute: D2H copy failed for output %zu: %s", - output.first, - cudaGetErrorString(cuda_err)); + ET_LOG(Error, "TensorRTBackend::execute: cudaStreamSynchronize failed: %s", cudaGetErrorString(cuda_err)); return Error::InvalidProgram; } - } - - cuda_err = cudaStreamSynchronize(stream); - if (cuda_err != cudaSuccess) { - ET_LOG(Error, "TensorRTBackend::execute: cudaStreamSynchronize failed: %s", cudaGetErrorString(cuda_err)); - return Error::InvalidProgram; + } else { + cuda_err = cudaEventRecord(engine->inflight_event, stream); + if (cuda_err != cudaSuccess) { + // Could not arm the completion marker; drain now so a later execute() or the + // destructor never reconfigures or frees exec_ctx while this enqueue runs. + ET_LOG(Error, "TensorRTBackend::execute: cudaEventRecord failed: %s", cudaGetErrorString(cuda_err)); + (void)cudaStreamSynchronize(stream); + engine->inflight_pending = false; + return Error::InvalidProgram; + } + engine->inflight_pending = true; } return Error::Ok; }