Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion cpp/include/torch_tensorrt/executorch/TensorRTBackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ struct EngineHandle {
TRTUniquePtr<nvinfer1::IRuntime> runtime;
TRTUniquePtr<nvinfer1::ICudaEngine> engine;
TRTUniquePtr<nvinfer1::IExecutionContext> exec_ctx;
cudaStream_t stream = nullptr;
std::vector<std::string> input_binding_names;
std::vector<std::string> output_binding_names;
std::vector<InputProfileBounds> input_profile_bounds;
Expand All @@ -63,6 +62,13 @@ struct EngineHandle {
int device_id = 0;
bool unified_memory = false;
std::mutex mu;
// Makes the skip-sync fast path safe to reuse: TensorRT forbids reconfiguring or
// destroying an execution context while one of its enqueues is in flight, so when
// execute() returns without an end sync it records this event; the next execute()
// and the destructor wait on it before touching exec_ctx. One event/flag pair
// suffices because a handle runs on a single thread at a time.
cudaEvent_t inflight_event = nullptr;
bool inflight_pending = false;

~EngineHandle();
};
Expand All @@ -84,5 +90,34 @@ class TensorRTBackend final : public ::executorch::runtime::BackendInterface {
void destroy(::executorch::runtime::DelegateHandle* handle) const override;
};

// Selects, for the calling thread, the CUDA stream the delegate runs TensorRT on;
// scope it around execution.
//
// Confines inference to a CUDA green context's SM partition when the caller
// passes a cuGreenCtxStreamCreate stream: confinement rides the stream (the green
// context need not be current), and cudaStreamPerThread — the no-guard default —
// is rejected while a green context is current. While active, device-resident
// outputs are left enqueued on the stream (no end sync) to compose with later GPU
// work.
//
// Contract: the stream is on the engine's device and outlives the guard; a handle
// is executed by one thread at a time. On the no-end-sync path (guard active, all
// I/O device-resident) execute() returns with the TensorRT enqueue still in flight
// on the stream; the delegate itself orders the next execute() on, and the
// destruction of, that handle after the work completes (via an internal completion
// event), so the caller need only synchronize the stream before reading
// device-resident outputs.
class CudaStreamGuard {
public:
explicit CudaStreamGuard(cudaStream_t stream);
~CudaStreamGuard();
CudaStreamGuard(const CudaStreamGuard&) = delete;
CudaStreamGuard& operator=(const CudaStreamGuard&) = delete;

private:
cudaStream_t prev_stream_;
bool prev_set_;
};

} // namespace executorch_backend
} // namespace torch_tensorrt
160 changes: 129 additions & 31 deletions cpp/src/torch_tensorrt/executorch/TensorRTBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,21 @@ using ::executorch::runtime::Span;
} \
} while (false)

namespace {
thread_local cudaStream_t g_user_stream = nullptr;
thread_local bool g_user_stream_set = false;
} // namespace

CudaStreamGuard::CudaStreamGuard(cudaStream_t stream) : prev_stream_(g_user_stream), prev_set_(g_user_stream_set) {
g_user_stream = stream;
g_user_stream_set = true;
}

CudaStreamGuard::~CudaStreamGuard() {
g_user_stream = prev_stream_;
g_user_stream_set = prev_set_;
}

void TRTLogger::log(Severity severity, const char* msg) noexcept {
if (severity <= Severity::kERROR) {
ET_LOG(Error, "TensorRT: %s", msg);
Expand All @@ -58,8 +73,23 @@ void TRTLogger::log(Severity severity, const char* msg) noexcept {

EngineHandle::~EngineHandle() {
cudaSetDevice(device_id);
if (stream != nullptr) {
cudaStreamSynchronize(stream);
// A fast-path execute() may have returned with its enqueue still in flight on the
// caller's stream, still using exec_ctx and the cached staging buffers. Wait on
// the recorded completion event before destroying the context or freeing the
// buffers. We wait on the event, not the stream, so this stays valid even if the
// caller already destroyed the stream. Non-skip executes synchronized inline, so
// inflight_pending is false there. Fall back to a device sync if no event exists.
if (inflight_event != nullptr) {
if (inflight_pending) {
cudaError_t err = cudaEventSynchronize(inflight_event);
if (err != cudaSuccess) {
ET_LOG(Error, "EngineHandle::~EngineHandle: cudaEventSynchronize failed: %s", cudaGetErrorString(err));
cudaGetLastError(); // clear sticky error; tear down regardless
}
inflight_pending = false;
}
} else {
cudaDeviceSynchronize();
}
for (void* p : cached_input_ptrs) {
if (p != nullptr) {
Expand All @@ -74,9 +104,9 @@ EngineHandle::~EngineHandle() {
exec_ctx.reset();
engine.reset();
runtime.reset();
if (stream != nullptr) {
cudaStreamDestroy(stream);
stream = nullptr;
if (inflight_event != nullptr) {
cudaEventDestroy(inflight_event);
inflight_event = nullptr;
}
}

Expand Down Expand Up @@ -226,9 +256,12 @@ Result<DelegateHandle*> TensorRTBackend::init(
return Error::InvalidProgram;
}

cuda_err = cudaStreamCreate(&handle->stream);
// Created while device_id is current so the event belongs to the engine's device.
// It orders a later execute()/teardown after a skip-sync enqueue (see execute()
// and ~EngineHandle). Blocking-sync so the host yields instead of busy-spinning.
cuda_err = cudaEventCreateWithFlags(&handle->inflight_event, cudaEventDisableTiming | cudaEventBlockingSync);
if (cuda_err != cudaSuccess) {
ET_LOG(Error, "TensorRTBackend::init: cudaStreamCreate failed: %s", cudaGetErrorString(cuda_err));
ET_LOG(Error, "TensorRTBackend::init: cudaEventCreateWithFlags failed: %s", cudaGetErrorString(cuda_err));
return Error::InvalidProgram;
}

Expand Down Expand Up @@ -298,22 +331,57 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle*
return Error::InvalidArgument;
}

cudaError_t cuda_err = cudaSetDevice(engine->device_id);
int entry_device = -1;
cudaError_t cuda_err = cudaGetDevice(&entry_device);
if (cuda_err != cudaSuccess) {
ET_LOG(
Error,
"TensorRTBackend::execute: cudaSetDevice(%d) failed: %s",
engine->device_id,
cudaGetErrorString(cuda_err));
ET_LOG(Error, "TensorRTBackend::execute: cudaGetDevice failed: %s", cudaGetErrorString(cuda_err));
return Error::InvalidProgram;
}
// Put the engine on its own device for multi-GPU correctness, restoring the
// caller's device on exit; green-context confinement rides the selected stream,
// independent of the current device/context.
const bool switch_device = (entry_device != engine->device_id);
if (switch_device) {
cuda_err = cudaSetDevice(engine->device_id);
if (cuda_err != cudaSuccess) {
ET_LOG(
Error,
"TensorRTBackend::execute: cudaSetDevice(%d) failed: %s",
engine->device_id,
cudaGetErrorString(cuda_err));
return Error::InvalidProgram;
}
}
struct DeviceRestore {
int device;
bool active;
~DeviceRestore() {
if (active) {
cudaSetDevice(device);
}
}
} device_restore{entry_device, switch_device};

std::unique_lock<std::mutex> lock(engine->mu);

nvinfer1::IExecutionContext* ctx = engine->exec_ctx.get();
cudaStream_t stream = engine->stream;
TORCHTRT_ET_CHECK_NOT_NULL(ctx, Error::InvalidState, "TensorRTBackend::execute: backend is not initialized");
TORCHTRT_ET_CHECK_NOT_NULL(stream, Error::InvalidState, "TensorRTBackend::execute: backend is not initialized");

// A prior fast-path execute() may have returned with its enqueue still in flight
// on the shared exec_ctx. Wait for it before reconfiguring the context below:
// TensorRT forbids mutating a context while one of its enqueues is in flight, and
// setInputShape/setTensorAddress run on the host, so this must be a host-side wait.
if (engine->inflight_pending) {
cuda_err = cudaEventSynchronize(engine->inflight_event);
engine->inflight_pending = false;
if (cuda_err != cudaSuccess) {
ET_LOG(Error, "TensorRTBackend::execute: cudaEventSynchronize failed: %s", cudaGetErrorString(cuda_err));
return Error::InvalidProgram;
}
}
cudaStream_t stream = g_user_stream_set ? g_user_stream : cudaStreamPerThread;
bool output_staged_to_host = false;
bool input_staged_from_host = false;

if (engine->cached_input_ptrs.empty()) {
engine->cached_input_ptrs.resize(num_inputs, nullptr);
Expand Down Expand Up @@ -392,6 +460,7 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle*
engine->cached_input_sizes[i] = needed;
}
bind_ptr = engine->cached_input_ptrs[i];
input_staged_from_host = true;
cuda_err = cudaMemcpyAsync(bind_ptr, et_in.const_data_ptr(), needed, cudaMemcpyHostToDevice, stream);
if (cuda_err != cudaSuccess) {
ET_LOG(
Expand Down Expand Up @@ -486,6 +555,7 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle*
engine->cached_output_sizes[o] = needed;
}
bind_ptr = engine->cached_output_ptrs[o];
output_staged_to_host = true;
outputs_needing_copy.push_back({o, bind_ptr});
}

Expand All @@ -499,28 +569,56 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle*
// 4. Enqueue inference on the current CUDA stream
// ------------------------------------------------------------------
if (!ctx->enqueueV3(stream)) {
ET_LOG(Error, "TensorRTBackend::execute: enqueueV3 failed");
ET_LOG(
Error,
"TensorRTBackend::execute: enqueueV3 failed. If a CUDA green context is "
"current, scope a CudaStreamGuard with a green-context stream: "
"cudaStreamPerThread is invalid while a green context is current.");
return Error::InvalidState;
}

for (auto& output : outputs_needing_copy) {
exec_aten::Tensor et_out = args[num_inputs + output.first]->toTensor();
cuda_err =
cudaMemcpyAsync(et_out.mutable_data_ptr(), output.second, et_out.nbytes(), cudaMemcpyDeviceToHost, stream);
// The engine work is now in flight on `stream`. Decide whether to wait for it:
// must_sync = an output is staged to host (the caller reads the D2H result on
// return), an input was staged from host (its async H2D read the caller's host
// buffer, which the caller may reuse once we return), or no caller stream is
// active (preserve the historical "results ready on return" behavior).
// Otherwise (caller stream + all I/O device-resident) leave the work enqueued so
// it composes with the caller's later GPU work, and record inflight_event so the
// next execute() and the destructor wait before reusing/freeing exec_ctx. The D2H
// copies live in the must_sync branch: an output staged to host always sets
// output_staged_to_host, so outputs_needing_copy is empty on the skip path.
const bool must_sync = output_staged_to_host || input_staged_from_host || !g_user_stream_set;
if (must_sync) {
for (auto& output : outputs_needing_copy) {
exec_aten::Tensor et_out = args[num_inputs + output.first]->toTensor();
cuda_err =
cudaMemcpyAsync(et_out.mutable_data_ptr(), output.second, et_out.nbytes(), cudaMemcpyDeviceToHost, stream);
if (cuda_err != cudaSuccess) {
ET_LOG(
Error,
"TensorRTBackend::execute: D2H copy failed for output %zu: %s",
output.first,
cudaGetErrorString(cuda_err));
return Error::InvalidProgram;
}
}
cuda_err = cudaStreamSynchronize(stream);
engine->inflight_pending = false;
if (cuda_err != cudaSuccess) {
ET_LOG(
Error,
"TensorRTBackend::execute: D2H copy failed for output %zu: %s",
output.first,
cudaGetErrorString(cuda_err));
ET_LOG(Error, "TensorRTBackend::execute: cudaStreamSynchronize failed: %s", cudaGetErrorString(cuda_err));
return Error::InvalidProgram;
}
}

cuda_err = cudaStreamSynchronize(stream);
if (cuda_err != cudaSuccess) {
ET_LOG(Error, "TensorRTBackend::execute: cudaStreamSynchronize failed: %s", cudaGetErrorString(cuda_err));
return Error::InvalidProgram;
} else {
cuda_err = cudaEventRecord(engine->inflight_event, stream);
if (cuda_err != cudaSuccess) {
// Could not arm the completion marker; drain now so a later execute() or the
// destructor never reconfigures or frees exec_ctx while this enqueue runs.
ET_LOG(Error, "TensorRTBackend::execute: cudaEventRecord failed: %s", cudaGetErrorString(cuda_err));
(void)cudaStreamSynchronize(stream);
engine->inflight_pending = false;
return Error::InvalidProgram;
}
engine->inflight_pending = true;
}
return Error::Ok;
}
Expand Down
Loading