Skip to content

Commit 1957cc4

Browse files
committed
feat: Support export and serialization workflows for MD-TRT
1 parent 6f81a66 commit 1957cc4

32 files changed

Lines changed: 2634 additions & 280 deletions

.github/workflows/build-test-linux-x86_64.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,12 @@ jobs:
526526
pushd .
527527
cd tests/py
528528
cd dynamo
529-
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_dynamo_distributed_test_results.xml distributed/test_nccl_ops.py
529+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_dynamo_distributed_test_results.xml \
530+
distributed/test_nccl_ops.py \
531+
distributed/test_native_nccl.py \
532+
distributed/test_export_save_load.py
533+
torchrun --nproc_per_node=2 distributed/test_native_nccl.py --multirank
534+
torchrun --nproc_per_node=2 distributed/test_export_save_load.py --multirank
530535
popd
531536
532537
concurrency:

core/runtime/TRTEngine.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,15 @@ TRTEngine::TRTEngine(
290290
TRTEngine::~TRTEngine() {
291291
torch::cuda::synchronize(device_info.id);
292292
trt_engine_profiler.reset();
293+
#ifdef ENABLE_TRT_NCCL_COLLECTIVES
294+
// Null out the NCCL communicator before destroying the execution context.
295+
// dist.destroy_process_group() may have already freed the ncclComm_t; if we
296+
// let IExecutionContext::~IExecutionContext() run with a dangling pointer it
297+
// will segfault.
298+
if (nccl_initialized && exec_ctx) {
299+
exec_ctx->setCommunicator(nullptr);
300+
}
301+
#endif
293302
exec_ctx.reset();
294303
cuda_engine.reset();
295304
if (empty_tensor_placeholder) {
@@ -554,6 +563,35 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt
554563

555564
#ifdef ENABLE_TRT_NCCL_COLLECTIVES
556565
bool TRTEngine::bind_nccl_comm() {
566+
// When group_name is empty (e.g. engine loaded from a serialized
567+
// ExportedProgram where the Python TorchTensorRTModule wrapper was
568+
// inlined and set_group_name() was never called), auto-resolve the
569+
// process group from the c10d registry. PyTorch assigns sequential
570+
// numeric names ("0", "1", ...) to process groups; probe until we
571+
// find one with an NCCL backend.
572+
if (this->group_name.empty() && this->is_md) {
573+
// PyTorch assigns sequential numeric names ("0", "1", ...) to process
574+
// groups. In practice most jobs create fewer than 10 groups; we probe
575+
// up to 20 to allow for destroyed-and-recreated groups.
576+
for (int i = 0; i < 20; ++i) {
577+
auto candidate = std::to_string(i);
578+
auto probe = c10d::resolve_process_group(candidate);
579+
if (probe != nullptr &&
580+
probe->getBackendType() == c10d::ProcessGroup::BackendType::NCCL) {
581+
this->group_name = candidate;
582+
LOG_INFO("Auto-resolved distributed group name to '" << candidate << "'");
583+
break;
584+
}
585+
}
586+
if (this->group_name.empty()) {
587+
LOG_WARNING(
588+
"This TRT engine requires NCCL (is_md=true) but no NCCL process group "
589+
"was found in the c10d registry. Ensure dist.init_process_group(backend='nccl') "
590+
"has been called before loading the engine. You can also set the group name "
591+
"manually via: engine.set_group_name(NCCL_GROUP_NAME)");
592+
}
593+
}
594+
557595
// Soft-return when the process group isn't available yet (e.g. at engine
558596
// construction time when the caller hasn't called dist.init_process_group()).
559597
auto pg = c10d::resolve_process_group(this->group_name);

core/runtime/execute_engine.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -330,22 +330,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
330330
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
331331
}
332332

333-
// Distributed setup - set NCCL communicator on TensorRT execution context
334-
#ifdef ENABLE_TRT_NCCL_COLLECTIVES
335-
if (compiled_engine->rank >= 0 && compiled_engine->world_size > 1) {
336-
bool result = compiled_engine->set_nccl_communicator_to_trt_context();
337-
if (!result) {
338-
LOG_ERROR("Failed to set NCCL communicator on TRT context");
339-
LOG_ERROR("This will cause collective operations to fail at runtime");
340-
LOG_ERROR("Make sure to call module.init_nccl_comm() after compilation");
341-
}
342-
} else {
343-
LOG_DEBUG(
344-
"Single-device mode (rank=" << compiled_engine->rank << ", world_size=" << compiled_engine->world_size
345-
<< ") - skipping NCCL setup");
346-
}
347-
#endif
348-
349333
// Block engine stream until results are available on caller stream
350334
at::cuda::CUDAEvent caller_exec_complete;
351335
caller_exec_complete.record(compiled_engine->caller_stream);

core/runtime/register_jit_hooks.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,26 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
126126
})
127127
.def("bind_nccl_comm", [](c10::intrusive_ptr<TRTEngine> self) { self->bind_nccl_comm(); })
128128
.def_readonly("nccl_initialized", &TRTEngine::nccl_initialized)
129+
#else
130+
.def(
131+
"set_group_name",
132+
[](c10::intrusive_ptr<TRTEngine> self, std::string group_name) {
133+
LOG_ERROR(
134+
"This build does not support MultiDevice TensorRT (ENABLE_TRT_NCCL_COLLECTIVES is OFF); set_group_name is a no-op");
135+
})
136+
.def(
137+
"bind_nccl_comm",
138+
[](c10::intrusive_ptr<TRTEngine> self) {
139+
LOG_ERROR(
140+
"This build does not support MultiDevice TensorRT (ENABLE_TRT_NCCL_COLLECTIVES is OFF); bind_nccl_comm is a no-op");
141+
})
142+
.def_property_readonly(
143+
"nccl_initialized",
144+
[](c10::intrusive_ptr<TRTEngine> self) -> bool {
145+
LOG_ERROR(
146+
"This build does not support MultiDevice TensorRT (ENABLE_TRT_NCCL_COLLECTIVES is OFF); nccl_initialized always returns false");
147+
return false;
148+
})
129149
#endif
130150
.def_pickle(
131151
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },

0 commit comments

Comments
 (0)