diff --git a/infini_train/include/core/ccl/ccl.h b/infini_train/include/core/ccl/ccl.h index 626cb078..23aa32b4 100644 --- a/infini_train/include/core/ccl/ccl.h +++ b/infini_train/include/core/ccl/ccl.h @@ -53,6 +53,9 @@ class CclImpl { nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, Stream *stream) const; + virtual void AlltoAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const; + virtual void Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const; diff --git a/infini_train/include/nn/parallel/parallel_functional.h b/infini_train/include/nn/parallel/parallel_functional.h index 2eed56f4..a5641d34 100644 --- a/infini_train/include/nn/parallel/parallel_functional.h +++ b/infini_train/include/nn/parallel/parallel_functional.h @@ -25,6 +25,9 @@ std::shared_ptr AllGather(const std::shared_ptr &output, const std std::shared_ptr ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, ReduceOpType reduce_op, const ProcessGroup *pg = nullptr, bool async_op = false); +std::shared_ptr AlltoAll(const std::shared_ptr &output, const std::shared_ptr &input, + const ProcessGroup *pg = nullptr, bool async_op = false); + std::vector>> Scatter(const std::vector> &input_tensors, const std::vector &device_ids, int dim); diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 74bf80c6..7db4c2b5 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -30,6 +30,17 @@ class Work; namespace infini_train::nn::parallel { +enum class P2POpType { + kSend, + kRecv, +}; + +struct P2POp { + P2POpType type; + std::shared_ptr tensor; + int peer_rank; +}; + class ProcessGroup { public: explicit ProcessGroup(Device::DeviceType backend, const std::string &process_group_name, @@ -52,12 +63,17 @@ class ProcessGroup { function::ReduceOpType reduce_op = function::ReduceOpType::kSum, bool async_op = false) const; + virtual std::shared_ptr AlltoAll(const std::shared_ptr &output, const std::shared_ptr &input, + bool async_op = false) const; + virtual std::shared_ptr Send(std::vector> tensors, int dest_rank, bool async_op = false) const; virtual std::shared_ptr Recv(std::vector> tensors, int src_rank, bool async_op = false) const; + virtual std::shared_ptr BatchSendRecv(const std::vector &ops, bool async_op = false) const; + // Legacy communication APIs (Single-stream) virtual std::vector> BroadCast(const std::vector> &input_tensors) const; diff --git a/infini_train/src/core/ccl/ccl.cc b/infini_train/src/core/ccl/ccl.cc index 92c14cc6..ccb33e3f 100644 --- a/infini_train/src/core/ccl/ccl.cc +++ b/infini_train/src/core/ccl/ccl.cc @@ -54,6 +54,11 @@ void CclImpl::ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_co LOG(FATAL) << "CclImpl::ReduceScatter is not implemented."; } +void CclImpl::AlltoAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const { + LOG(FATAL) << "CclImpl::AlltoAll is not implemented."; +} + void CclImpl::Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const { LOG(FATAL) << "CclImpl::Send is not implemented."; diff --git a/infini_train/src/core/ccl/cuda/nccl_impl.cc b/infini_train/src/core/ccl/cuda/nccl_impl.cc index 9e4b1a0d..939b7f11 100644 --- a/infini_train/src/core/ccl/cuda/nccl_impl.cc +++ b/infini_train/src/core/ccl/cuda/nccl_impl.cc @@ -12,6 +12,10 @@ #include "infini_train/src/core/ccl/cuda/nccl_common.h" #include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" +#ifndef NCCL_VERSION_CODE +#define NCCL_VERSION_CODE NCCL_VERSION(NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH) +#endif + namespace infini_train::core::cuda { namespace { @@ -146,6 +150,48 @@ void NcclImpl::ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_c kNcclReduceOpMap.at(reduce_op), GetNcclComm(comm), GetCudaStream(stream))); } +void NcclImpl::AlltoAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const { + auto nccl_comm = GetNcclComm(comm); + auto cuda_stream = GetCudaStream(stream); + CHECK_NE(sendbuff, recvbuff) << "NcclImpl::AlltoAll does not support in-place operation."; + + // NCCL 2.28.3+ provides native host collective ncclAlltoAll with the same contiguous rank-major layout. + // Older NCCL releases do not expose it, so fall back to an equivalent grouped ncclSend/ncclRecv schedule. +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 3) + NCCL_CHECK(ncclAlltoAll(sendbuff, recvbuff, count, kNcclDtypeMap.at(dtype), nccl_comm, cuda_stream)); +#else + int nranks = 0; + int rank = 0; + NCCL_CHECK(ncclCommCount(nccl_comm, &nranks)); + NCCL_CHECK(ncclCommUserRank(nccl_comm, &rank)); + CHECK_GT(nranks, 0); + CHECK_GE(rank, 0); + CHECK_LT(rank, nranks); + + const size_t chunk_bytes = count * kDataTypeToSize.at(dtype); + auto send_ptr = static_cast(sendbuff); + auto recv_ptr = static_cast(recvbuff); + + if (chunk_bytes > 0) { + CUDA_CHECK(cudaMemcpyAsync(recv_ptr + static_cast(rank) * chunk_bytes, + send_ptr + static_cast(rank) * chunk_bytes, chunk_bytes, + cudaMemcpyDeviceToDevice, cuda_stream)); + } + + NCCL_CHECK(ncclGroupStart()); + for (int peer = 0; peer < nranks; ++peer) { + if (peer == rank) { + continue; + } + const auto offset = static_cast(peer) * chunk_bytes; + NCCL_CHECK(ncclSend(send_ptr + offset, count, kNcclDtypeMap.at(dtype), peer, nccl_comm, cuda_stream)); + NCCL_CHECK(ncclRecv(recv_ptr + offset, count, kNcclDtypeMap.at(dtype), peer, nccl_comm, cuda_stream)); + } + NCCL_CHECK(ncclGroupEnd()); +#endif +} + void NcclImpl::Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const { NCCL_CHECK(ncclSend(buff, count, kNcclDtypeMap.at(dtype), peer, GetNcclComm(comm), GetCudaStream(stream))); diff --git a/infini_train/src/core/ccl/cuda/nccl_impl.h b/infini_train/src/core/ccl/cuda/nccl_impl.h index fca177fd..42d0664a 100644 --- a/infini_train/src/core/ccl/cuda/nccl_impl.h +++ b/infini_train/src/core/ccl/cuda/nccl_impl.h @@ -42,6 +42,9 @@ class NcclImpl final : public CclImpl { nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, Stream *stream) const override; + void AlltoAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, + Stream *stream) const override; + void Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const override; diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc index ffd218d7..0ed4148f 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -40,6 +40,15 @@ std::shared_ptr ReduceScatter(const std::shared_ptr &output, const return pg->ReduceScatter(output, input, reduce_op, async_op); } +std::shared_ptr AlltoAll(const std::shared_ptr &output, const std::shared_ptr &input, + const ProcessGroup *pg, bool async_op) { + auto device = output->GetDevice().type(); + if (pg == nullptr) { + pg = ProcessGroupFactory::Instance(device)->GetDefaultProcessGroup(); + } + return pg->AlltoAll(output, input, async_op); +} + std::vector>> Scatter(const std::vector> &input_tensors, const std::vector &devices, int dim) { std::vector>> output_tensors; diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 3c4c4910..7c8faa8b 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -194,6 +194,33 @@ std::shared_ptr ProcessGroup::ReduceScatter(const std::shared_ptr } } +std::shared_ptr ProcessGroup::AlltoAll(const std::shared_ptr &output, + const std::shared_ptr &input, bool async_op) const { + auto device = input->GetDevice(); + CHECK_EQ(device, output->GetDevice()); + CHECK(input->Dtype() == output->Dtype()); + CHECK_EQ(input->NumElements(), output->NumElements()); + CHECK_EQ(input->NumElements() % world_size_, 0) << "AlltoAll input must be evenly divisible by world size"; + core::DeviceGuard guard(device); + auto *compute_stream = runtime_impl_->GetStream(device); + auto *comm_stream = device_stream_map_.at(device.index()); + auto comm = device_comm_map_.at(device.index()); + + auto work = std::make_shared(device, comm); + runtime_impl_->EventRecord(work->ready_event(), compute_stream); + runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0); + ccl_impl_->AlltoAll(input->DataPtr(), output->DataPtr(), input->NumElements() / world_size_, input->Dtype(), comm, + comm_stream); + runtime_impl_->EventRecord(work->done_event(), comm_stream); + + if (async_op) { + return work; + } else { + work->WaitNonBlocking(); + return nullptr; + } +} + std::shared_ptr ProcessGroup::Send(std::vector> tensors, int dest_rank, bool async_op) const { CHECK_GT(tensors.size(), 0); @@ -248,6 +275,46 @@ std::shared_ptr ProcessGroup::Recv(std::vector> te } } +std::shared_ptr ProcessGroup::BatchSendRecv(const std::vector &ops, bool async_op) const { + CHECK_GT(ops.size(), 0); + CHECK_NOTNULL(ops[0].tensor); + auto device = ops[0].tensor->GetDevice(); + core::DeviceGuard guard(device); + auto *compute_stream = runtime_impl_->GetStream(device); + auto *comm_stream = device_stream_map_.at(device.index()); + auto comm = device_comm_map_.at(device.index()); + + auto work = std::make_shared(device, comm); + runtime_impl_->EventRecord(work->ready_event(), compute_stream); + runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0); + + { + core::CclGroupGuard ccl_group_guard(backend_); + for (const auto &op : ops) { + CHECK_NOTNULL(op.tensor); + CHECK_EQ(device, op.tensor->GetDevice()); + CHECK_GE(op.peer_rank, 0); + CHECK_LT(op.peer_rank, world_size_); + if (op.type == P2POpType::kSend) { + ccl_impl_->Send(op.tensor->DataPtr(), op.tensor->NumElements(), op.tensor->Dtype(), op.peer_rank, comm, + comm_stream); + } else { + ccl_impl_->Recv(op.tensor->DataPtr(), op.tensor->NumElements(), op.tensor->Dtype(), op.peer_rank, comm, + comm_stream); + } + } + } + + runtime_impl_->EventRecord(work->done_event(), comm_stream); + + if (async_op) { + return work; + } else { + work->WaitNonBlocking(); + return nullptr; + } +} + std::vector> ProcessGroup::BroadCast(const std::vector> &input_tensors) const { std::vector> outputs;