multidevice: enable NCCL Copy Engine allgather via CTAPolicy=ZERO#6046

Open
saivishal1999 wants to merge 1 commit into main from nccl-ce-allgather

Conversation

@saivishal1999 (Collaborator)


When `SymmetricMemoryBackend::PyTorchNccl` is active, configure the
ProcessGroupNCCL with `NCCL_CTA_POLICY_ZERO` at creation time. Combined
with NCCL-window-registered buffers (allocated via `empty_strided_p2p`
and rendezvous'd through `SymmetricTensor::setupRemoteHandles`), this
causes NCCL to select the Copy Engine (DMA) path for allgather, freeing
SM/CTA resources.

Add `SymmetricTensorTest.CopyEngineAllgather` to verify correctness of
an allgather over symm_mem buffers on all 8 ranks. Confirmed via NCCL
logs:
  - `Comm config CTA policy flags set to 2` on each rank
  - `Inserted window ... into address map` for both input/output buffers
  - `AllGather [Copy Engine]: ... -> cudaMemcpy; CE synchronization with NVLS`
@greptile-apps
Contributor

greptile-apps Bot commented May 6, 2026

Greptile Summary

Configures each ProcessGroupNCCL with NCCL_CTA_POLICY_ZERO at creation time when the PyTorchNccl symmetric-memory backend is active, and adds a correctness test for an allgather over window-registered symmetric-memory buffers. The CTA policy configuration is correctly guarded by #ifdef NCCL_HAS_CTA_POLICY and the required ipc_utils.h include is added.

  • communicator.cpp: The CTAPolicy=ZERO is applied inside createBackend, which is called lazily for every team key; this means sub-team PGs (e.g., for tensor or pipeline parallelism) will also carry the policy when PyTorchNccl is active.
  • test_multidevice_symmetric_tensor.cpp: New CopyEngineAllgather test allocates input/output via SymmetricTensor::allocate, rendezvous both buffers, runs _allgather_base, and validates each rank's slice; correctness of the CE path itself is verified only via NCCL log messages (noted in the PR description).

Confidence Score: 4/5

Safe to merge; the CTA policy change is opt-in via the PyTorchNccl backend flag and correctness is preserved.

The change is small and well-guarded by #ifdef NCCL_HAS_CTA_POLICY. The one thing worth a second look is that CTAPolicy=ZERO is stamped onto every NCCL PG created while PyTorchNccl is active, not only the world communicator. Workloads mixing symmetric-memory allgather with sub-team collectives may see performance changes on those sub-team groups.

csrc/multidevice/communicator.cpp — the scope of which PGs receive CTAPolicy=ZERO deserves a brief look before merging into workloads that use multi-team parallelism strategies.

Important Files Changed

| Filename | Overview |
| --- | --- |
| `csrc/multidevice/communicator.cpp` | Adds an include for `ipc_utils.h` and sets `NCCL_CTA_POLICY_ZERO` on every NCCL PG when the PyTorchNccl backend is active; the policy is applied broadly to all teams, not just the world communicator. |
| `tests/cpp/test_multidevice_symmetric_tensor.cpp` | Adds the `CopyEngineAllgather` test; correctly gates on PyTorchNccl + NCCL + multi-rank, allocates symm_mem buffers, rendezvous both, fills the input, runs allgather, and validates the output. Minor redundancy: manual team construction instead of `getWorld()`. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as CopyEngineAllgather Test
    participant Alloc as SymmetricTensor::allocate
    participant Backend as createBackend (NCCL PG)
    participant Rdv as symmetric_memory::rendezvous
    participant NCCL as _allgather_base (CE path)

    Test->>Alloc: allocate(input shape)
    Alloc->>Backend: initSymmMemBackendAndGetGroup
    Backend-->>Backend: CTAPolicy=ZERO (NCCL_HAS_CTA_POLICY guard)
    Test->>Alloc: allocate(output shape)
    Note over Alloc: returns cached PG

    Test->>Rdv: input_sym.setupRemoteHandles
    Note over Rdv: NCCL window-registers input buffer
    Test->>Rdv: output_sym.setupRemoteHandles
    Note over Rdv: NCCL window-registers output buffer

    Test->>Test: input.fill_(rank + 1)
    Test->>NCCL: _allgather_base(output, input)
    Note over NCCL: CTAPolicy=ZERO + window-registered = CE DMA path
    NCCL-->>Test: work->wait()
    Test->>Test: validate output slices
```

Reviews (1): Last reviewed commit: "multidevice: enable NCCL Copy Engine all..."

Comment on lines +160 to +163:

```cpp
#ifdef NCCL_HAS_CTA_POLICY
  if (getSymmetricMemoryBackend() == SymmetricMemoryBackend::PyTorchNccl) {
    pg_opts->config.CTAPolicy = NCCL_CTA_POLICY_ZERO;
  }
```
P2 CTAPolicy=ZERO applied to every NCCL PG, not just the world PG

createBackend is called lazily by getBackendForTeam whenever any new team key is created — including sub-team PGs (e.g., tensor-parallel or pipeline-parallel groups). When PyTorchNccl is active, every future NCCL process group will be created with CTAPolicy=ZERO, not only the world communicator used for the symmetric-memory allgather. Operations on those sub-team groups (allreduce, reduce-scatter, broadcast, etc.) that cannot use the CE path will be executed under the ZERO-CTA policy, which may meaningfully change their latency/throughput relative to the default policy.

Comment on lines +431 to +434:

```cpp
  Team all_ranks(world_size);
  std::iota(all_ranks.begin(), all_ranks.end(), 0);
  c10d::Backend* backend =
      communicator_->getBackendForTeam(all_ranks, CommunicatorBackend::kNccl);
```

P2 The test manually replicates the logic inside Communicator::getWorld() — constructing a full-rank Team and calling getBackendForTeam. Using getWorld directly is cleaner and is the existing idiom in the codebase.

Suggested change:

```diff
- Team all_ranks(world_size);
- std::iota(all_ranks.begin(), all_ranks.end(), 0);
- c10d::Backend* backend =
-     communicator_->getBackendForTeam(all_ranks, CommunicatorBackend::kNccl);
+ c10d::Backend* backend = communicator_->getWorld(CommunicatorBackend::kNccl);
```

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
