Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions infini_train/include/core/ccl/ccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class CclImpl {
nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm,
Stream *stream) const;

// All-to-all exchange between the ranks of `comm`, enqueued on `stream`.
// `count` is the number of `dtype` elements exchanged with each peer (the per-rank chunk size),
// not the total length of `sendbuff` / `recvbuff`.
// The base-class implementation aborts with LOG(FATAL); backends such as NcclImpl override it.
virtual void AlltoAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm,
Stream *stream) const;

// Point-to-point send of `count` elements of `dtype` from `buff` to rank `peer` of `comm`, enqueued on `stream`.
// The base-class implementation aborts with LOG(FATAL); backends such as NcclImpl override it.
virtual void Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm,
Stream *stream) const;

Expand Down
3 changes: 3 additions & 0 deletions infini_train/include/nn/parallel/parallel_functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std
// Functional reduce-scatter wrapper: forwards `output`, `input`, `reduce_op` and `async_op` to
// ProcessGroup::ReduceScatter and returns its result.
// NOTE(review): the handling of pg == nullptr is not visible in this excerpt — presumably the default
// process group is used, as in AlltoAll; confirm against the definition.
std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
ReduceOpType reduce_op, const ProcessGroup *pg = nullptr, bool async_op = false);

// Functional all-to-all wrapper: exchanges `input` into `output` through a process group.
// When `pg` is nullptr, the default process group for the output tensor's device type is used.
// Returns whatever ProcessGroup::AlltoAll returns; for the base ProcessGroup that is a Work handle when
// `async_op` is true and nullptr otherwise.
std::shared_ptr<Work> AlltoAll(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
const ProcessGroup *pg = nullptr, bool async_op = false);

// Scatters `input_tensors` over the given devices and returns a nested vector of tensors.
// NOTE(review): presumably each input is split along `dim` with one piece per device — the body and the
// meaning of the two nesting levels are not visible in this excerpt; confirm against the definition.
std::vector<std::vector<std::shared_ptr<Tensor>>> Scatter(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<Device> &device_ids, int dim);

Expand Down
16 changes: 16 additions & 0 deletions infini_train/include/nn/parallel/process_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ class Work;

namespace infini_train::nn::parallel {

// Direction of one point-to-point operation inside a batched send/recv.
enum class P2POpType {
    kSend = 0, // the op's tensor is the source buffer and is sent to the peer
    kRecv = 1, // the op's tensor is the destination buffer and is filled from the peer
};

// One entry of a batched point-to-point exchange (consumed by ProcessGroup::BatchSendRecv).
//
// All members have defaults so that a default-constructed P2POp never carries indeterminate values:
// the default peer rank (-1) and the default (null) tensor are both rejected by BatchSendRecv's checks,
// so a forgotten field aborts with a clear message instead of being undefined behaviour.
// The type remains an aggregate, so `P2POp{type, tensor, peer}` keeps working.
struct P2POp {
    // Whether `tensor` is sent to or received from `peer_rank`.
    P2POpType type = P2POpType::kSend;
    // Buffer to send from / receive into.
    std::shared_ptr<Tensor> tensor;
    // Rank of the remote peer; BatchSendRecv requires 0 <= peer_rank < world size.
    int peer_rank = -1;
};

class ProcessGroup {
public:
explicit ProcessGroup(Device::DeviceType backend, const std::string &process_group_name,
Expand All @@ -52,12 +63,17 @@ class ProcessGroup {
function::ReduceOpType reduce_op = function::ReduceOpType::kSum,
bool async_op = false) const;

// All-to-all exchange of `input` into `output` across the ranks of this group.
// The base implementation requires both tensors to share device, dtype and element count, and the element
// count to be divisible by the world size; each rank exchanges NumElements / world_size elements per peer.
// Returns a Work handle when `async_op` is true; otherwise calls WaitNonBlocking() on it and returns nullptr.
virtual std::shared_ptr<Work> AlltoAll(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
bool async_op = false) const;

// Sends `tensors` to rank `dest_rank`. The base implementation requires at least one tensor.
// NOTE(review): the rest of the definition is not visible in this excerpt — presumably it follows the same
// async_op convention as AlltoAll (Work handle when async, nullptr otherwise); confirm.
virtual std::shared_ptr<Work> Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
bool async_op = false) const;

// Receives into `tensors` from rank `src_rank`.
// NOTE(review): only the tail of the definition is visible in this excerpt — presumably it mirrors Send and
// follows the same async_op convention (Work handle when async, nullptr otherwise); confirm.
virtual std::shared_ptr<Work> Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank,
bool async_op = false) const;

// Issues all sends and receives described by `ops` inside a single CCL group on the communication stream.
// The base implementation requires `ops` to be non-empty, every tensor to be non-null and on the same device
// as the first op's tensor, and every peer rank to lie in [0, world size).
// Returns a Work handle when `async_op` is true; otherwise calls WaitNonBlocking() on it and returns nullptr.
virtual std::shared_ptr<Work> BatchSendRecv(const std::vector<P2POp> &ops, bool async_op = false) const;

// Legacy communication APIs (Single-stream)
// Legacy broadcast: takes a list of input tensors and returns a list of output tensors.
// NOTE(review): the body is not visible in this excerpt — how inputs map to outputs and devices needs to be
// confirmed against the definition.
virtual std::vector<std::shared_ptr<Tensor>>
BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const;
Expand Down
5 changes: 5 additions & 0 deletions infini_train/src/core/ccl/ccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ void CclImpl::ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_co
LOG(FATAL) << "CclImpl::ReduceScatter is not implemented.";
}

// Default AlltoAll for CCL backends that do not provide one.
// None of the arguments are inspected; the call always terminates through LOG(FATAL).
void CclImpl::AlltoAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype,
                       const CclComm *comm, Stream *stream) const {
    // Fail loudly so a missing backend override cannot be mistaken for a completed exchange.
    LOG(FATAL) << "CclImpl::AlltoAll is not implemented.";
}

void CclImpl::Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm,
Stream *stream) const {
LOG(FATAL) << "CclImpl::Send is not implemented.";
Expand Down
46 changes: 46 additions & 0 deletions infini_train/src/core/ccl/cuda/nccl_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
#include "infini_train/src/core/ccl/cuda/nccl_common.h"
#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h"

// Fallback definition used only if the included NCCL headers did not already define NCCL_VERSION_CODE:
// derive it from the major/minor/patch components so the version gate in NcclImpl::AlltoAll can be evaluated.
#ifndef NCCL_VERSION_CODE
#define NCCL_VERSION_CODE NCCL_VERSION(NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH)
#endif

namespace infini_train::core::cuda {
namespace {

Expand Down Expand Up @@ -146,6 +150,48 @@ void NcclImpl::ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_c
kNcclReduceOpMap.at(reduce_op), GetNcclComm(comm), GetCudaStream(stream)));
}

// All-to-all exchange over NCCL, enqueued on the CUDA stream wrapped by `stream`.
//
// Buffer layout (contiguous, rank-major): `sendbuff` holds one chunk of `count` elements per rank, and
// chunk `p` is delivered to rank `p`; `recvbuff` holds one chunk of `count` elements per rank, and chunk `p`
// is filled with the data coming from rank `p`.
//
// @param sendbuff  Source buffer (nranks * count elements).
// @param recvbuff  Destination buffer (nranks * count elements); must differ from `sendbuff`.
// @param count     Number of elements exchanged with each peer (not the total buffer length).
// @param dtype     Element type; must be present in kNcclDtypeMap (and kDataTypeToSize on the fallback path).
// @param comm      Communicator wrapper resolved through GetNcclComm.
// @param stream    Stream wrapper resolved through GetCudaStream.
void NcclImpl::AlltoAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm,
                        Stream *stream) const {
    auto nccl_comm = GetNcclComm(comm);
    auto cuda_stream = GetCudaStream(stream);
    CHECK_NE(sendbuff, recvbuff) << "NcclImpl::AlltoAll does not support in-place operation.";

    // Resolve the NCCL dtype exactly once, before choosing a path: this avoids repeating the map lookup for
    // every peer in the fallback loop and makes an unsupported dtype fail the same way on both paths,
    // independent of the communicator size.
    const auto nccl_dtype = kNcclDtypeMap.at(dtype);

    // NCCL 2.28.3+ provides native host collective ncclAlltoAll with the same contiguous rank-major layout.
    // Older NCCL releases do not expose it, so fall back to an equivalent grouped ncclSend/ncclRecv schedule.
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 3)
    NCCL_CHECK(ncclAlltoAll(sendbuff, recvbuff, count, nccl_dtype, nccl_comm, cuda_stream));
#else
    int nranks = 0;
    int rank = 0;
    NCCL_CHECK(ncclCommCount(nccl_comm, &nranks));
    NCCL_CHECK(ncclCommUserRank(nccl_comm, &rank));
    CHECK_GT(nranks, 0);
    CHECK_GE(rank, 0);
    CHECK_LT(rank, nranks);

    const size_t chunk_bytes = count * kDataTypeToSize.at(dtype);
    auto send_ptr = static_cast<const char *>(sendbuff);
    auto recv_ptr = static_cast<char *>(recvbuff);

    // The local chunk never goes through NCCL: copy it device-to-device on the same stream.
    if (chunk_bytes > 0) {
        const auto self_offset = static_cast<size_t>(rank) * chunk_bytes;
        CUDA_CHECK(cudaMemcpyAsync(recv_ptr + self_offset, send_ptr + self_offset, chunk_bytes,
                                   cudaMemcpyDeviceToDevice, cuda_stream));
    }

    // Post every send/recv pair inside one NCCL group so they are submitted together.
    NCCL_CHECK(ncclGroupStart());
    for (int peer = 0; peer < nranks; ++peer) {
        if (peer == rank) {
            continue;
        }
        const auto offset = static_cast<size_t>(peer) * chunk_bytes;
        NCCL_CHECK(ncclSend(send_ptr + offset, count, nccl_dtype, peer, nccl_comm, cuda_stream));
        NCCL_CHECK(ncclRecv(recv_ptr + offset, count, nccl_dtype, peer, nccl_comm, cuda_stream));
    }
    NCCL_CHECK(ncclGroupEnd());
#endif
}

void NcclImpl::Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm,
Stream *stream) const {
NCCL_CHECK(ncclSend(buff, count, kNcclDtypeMap.at(dtype), peer, GetNcclComm(comm), GetCudaStream(stream)));
Expand Down
3 changes: 3 additions & 0 deletions infini_train/src/core/ccl/cuda/nccl_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class NcclImpl final : public CclImpl {
nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm,
Stream *stream) const override;

// NCCL all-to-all. Uses the native ncclAlltoAll when built against NCCL >= 2.28.3, otherwise a grouped
// ncclSend/ncclRecv exchange plus a device-to-device copy of the local chunk.
// `count` is the number of elements exchanged with each peer; sendbuff == recvbuff (in-place) is rejected.
void AlltoAll(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm,
Stream *stream) const override;

// NCCL point-to-point send: forwards `count` elements of `dtype` from `buff` to rank `peer` via ncclSend,
// using the NCCL communicator and CUDA stream resolved from `comm` and `stream`.
void Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm,
Stream *stream) const override;

Expand Down
9 changes: 9 additions & 0 deletions infini_train/src/nn/parallel/parallel_functional.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output, const
return pg->ReduceScatter(output, input, reduce_op, async_op);
}

std::shared_ptr<Work> AlltoAll(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
const ProcessGroup *pg, bool async_op) {
auto device = output->GetDevice().type();
if (pg == nullptr) {
pg = ProcessGroupFactory::Instance(device)->GetDefaultProcessGroup();
}
return pg->AlltoAll(output, input, async_op);
}

std::vector<std::vector<std::shared_ptr<Tensor>>> Scatter(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<Device> &devices, int dim) {
std::vector<std::vector<std::shared_ptr<Tensor>>> output_tensors;
Expand Down
67 changes: 67 additions & 0 deletions infini_train/src/nn/parallel/process_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,33 @@ std::shared_ptr<Work> ProcessGroup::ReduceScatter(const std::shared_ptr<Tensor>
}
}

std::shared_ptr<Work> ProcessGroup::AlltoAll(const std::shared_ptr<Tensor> &output,
const std::shared_ptr<Tensor> &input, bool async_op) const {
auto device = input->GetDevice();
CHECK_EQ(device, output->GetDevice());
CHECK(input->Dtype() == output->Dtype());
CHECK_EQ(input->NumElements(), output->NumElements());
CHECK_EQ(input->NumElements() % world_size_, 0) << "AlltoAll input must be evenly divisible by world size";
core::DeviceGuard guard(device);
auto *compute_stream = runtime_impl_->GetStream(device);
auto *comm_stream = device_stream_map_.at(device.index());
auto comm = device_comm_map_.at(device.index());

auto work = std::make_shared<Work>(device, comm);
runtime_impl_->EventRecord(work->ready_event(), compute_stream);
runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0);
ccl_impl_->AlltoAll(input->DataPtr(), output->DataPtr(), input->NumElements() / world_size_, input->Dtype(), comm,
comm_stream);
runtime_impl_->EventRecord(work->done_event(), comm_stream);

if (async_op) {
return work;
} else {
work->WaitNonBlocking();
return nullptr;
}
}

std::shared_ptr<Work> ProcessGroup::Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
bool async_op) const {
CHECK_GT(tensors.size(), 0);
Expand Down Expand Up @@ -248,6 +275,46 @@ std::shared_ptr<Work> ProcessGroup::Recv(std::vector<std::shared_ptr<Tensor>> te
}
}

// Issues a batch of point-to-point sends and receives as one CCL group.
//
// The device of the first op's tensor selects the communication stream and communicator; every other op must
// use a non-null tensor on that same device and a peer rank in [0, world_size_). Ops are validated and
// enqueued one by one, in the order given, while the CCL group guard is alive.
//
// @return a Work handle when `async_op` is true; otherwise WaitNonBlocking() is called on it and nullptr is
//         returned.
std::shared_ptr<Work> ProcessGroup::BatchSendRecv(const std::vector<P2POp> &ops, bool async_op) const {
    CHECK_GT(ops.size(), 0);
    CHECK_NOTNULL(ops[0].tensor);

    auto device = ops[0].tensor->GetDevice();
    core::DeviceGuard guard(device);

    auto *main_stream = runtime_impl_->GetStream(device);
    auto *p2p_stream = device_stream_map_.at(device.index());
    auto p2p_comm = device_comm_map_.at(device.index());

    auto work = std::make_shared<Work>(device, p2p_comm);

    // Make the communication stream wait for the work already queued on the compute stream.
    runtime_impl_->EventRecord(work->ready_event(), main_stream);
    runtime_impl_->StreamWaitEvent(p2p_stream, work->ready_event(), 0);

    {
        // Scope of the group guard: all sends/recvs below belong to one CCL group.
        core::CclGroupGuard ccl_group_guard(backend_);
        for (const auto &op : ops) {
            CHECK_NOTNULL(op.tensor);
            CHECK_EQ(device, op.tensor->GetDevice());
            CHECK_GE(op.peer_rank, 0);
            CHECK_LT(op.peer_rank, world_size_);

            const bool is_send = (op.type == P2POpType::kSend);
            if (!is_send) {
                ccl_impl_->Recv(op.tensor->DataPtr(), op.tensor->NumElements(), op.tensor->Dtype(), op.peer_rank,
                                p2p_comm, p2p_stream);
                continue;
            }
            ccl_impl_->Send(op.tensor->DataPtr(), op.tensor->NumElements(), op.tensor->Dtype(), op.peer_rank,
                            p2p_comm, p2p_stream);
        }
    }

    runtime_impl_->EventRecord(work->done_event(), p2p_stream);

    if (!async_op) {
        work->WaitNonBlocking();
        return nullptr;
    }
    return work;
}

std::vector<std::shared_ptr<Tensor>>
ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const {
std::vector<std::shared_ptr<Tensor>> outputs;
Expand Down
Loading