From f6396aa0fa952eca3ed90d0527cc65302cdb15e1 Mon Sep 17 00:00:00 2001
From: kilinchange
Date: Fri, 12 Dec 2025 09:44:55 +0000
Subject: [PATCH 1/3] feat: refactor ProcessGroup communication ops by
 decoupling compute and communication streams for tp/pp, with stream
 synchronization controlled via 'async_op'

---
 .../include/nn/parallel/parallel_functional.h |   7 +-
 .../include/nn/parallel/process_group.h       |  24 +-
 .../include/nn/parallel/reduce_op_type.h      |   6 +
 infini_train/include/nn/parallel/work.h       |   2 +-
 .../src/nn/parallel/parallel_functional.cc    |  55 ++--
 .../src/nn/parallel/pp/pipeline_schedule.cc   |   2 -
 infini_train/src/nn/parallel/pp/send_recv.cc  |  13 +-
 infini_train/src/nn/parallel/process_group.cc | 262 ++++++++++++------
 infini_train/src/nn/parallel/reducer.cc       |   4 +-
 .../src/nn/parallel/tensor_parallel.cc        |  12 +-
 10 files changed, 239 insertions(+), 148 deletions(-)

diff --git a/infini_train/include/nn/parallel/parallel_functional.h b/infini_train/include/nn/parallel/parallel_functional.h
index 25dccaf3..a6ad9952 100644
--- a/infini_train/include/nn/parallel/parallel_functional.h
+++ b/infini_train/include/nn/parallel/parallel_functional.h
@@ -21,13 +21,14 @@ std::vector<std::vector<std::shared_ptr<Tensor>>> Scatter(const std::vector<std
 std::vector<std::shared_ptr<Tensor>> Gather(const std::vector<std::vector<std::shared_ptr<Tensor>>> &outputs,
                                             const Device *target_device, int dim);
 
-void AllReduce(const std::shared_ptr<Tensor> &tensor, ReduceOpType reduce_op, const ProcessGroup *pg = nullptr);
+void AllReduce(const std::shared_ptr<Tensor> &tensor, ReduceOpType reduce_op, const ProcessGroup *pg = nullptr,
+               bool async_op = false);
 
 void AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
-               const ProcessGroup *pg = nullptr);
+               const ProcessGroup *pg = nullptr, bool async_op = false);
 
 void ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input, ReduceOpType reduce_op,
-                   const ProcessGroup *pg = nullptr);
+                   const ProcessGroup *pg = nullptr, bool async_op = false);
 
 std::vector<std::vector<std::shared_ptr<Tensor>>>
 BroadcastCoalescedReshape(const std::vector<std::shared_ptr<Tensor>> &tensors,
diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h
index d9896c85..d7f8c879 100644
--- a/infini_train/include/nn/parallel/process_group.h
+++ b/infini_train/include/nn/parallel/process_group.h
@@ -37,14 +37,21 @@ class ProcessGroup {
 
     int GetGroupRank(int global_rank) const;
 
-    // Communication operations
-    void AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;
+    // Asynchronous communication APIs (Compute / Communication stream decoupled)
+    std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor,
+                                    const function::AllreduceOptions &opts) const;
 
-    void AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input) const;
+    std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
+                                    bool async_op) const;
 
-    void ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
-                       function::ReduceOpType reduce_op) const;
+    std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
+                                        const function::AllreduceOptions &opts) const;
+    std::shared_ptr<Work> Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank, bool async_op) const;
+
+    std::shared_ptr<Work> Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank, bool async_op) const;
+
+    // Legacy communication APIs (Single-stream)
     std::vector<std::shared_ptr<Tensor>> BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const;
 
     std::vector<std::shared_ptr<Tensor>>
@@ -56,13 +63,6 @@ class ProcessGroup {
     std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, const Device *destination,
                                    int64_t dim) const;
 
-    std::vector<std::shared_ptr<Tensor>> NcclSend(std::vector<std::shared_ptr<Tensor>>
tensors, int dest_rank) const; - - std::vector> NcclRecv(std::vector> tensors, int src_rank) const; - - // Async communication functions - std::shared_ptr AllReduceAsync(const std::shared_ptr &tensor, function::ReduceOpType reduce_op) const; - private: void InitSingleProcess(const std::vector &ranks); diff --git a/infini_train/include/nn/parallel/reduce_op_type.h b/infini_train/include/nn/parallel/reduce_op_type.h index 0178307e..f3ff6251 100644 --- a/infini_train/include/nn/parallel/reduce_op_type.h +++ b/infini_train/include/nn/parallel/reduce_op_type.h @@ -10,4 +10,10 @@ enum class ReduceOpType : int8_t { kMax, kAvg, }; + +struct AllreduceOptions { + ReduceOpType reduce_op_type = ReduceOpType::kSum; + bool async_op = false; +}; + } // namespace infini_train::nn::parallel::function diff --git a/infini_train/include/nn/parallel/work.h b/infini_train/include/nn/parallel/work.h index 3f304b7a..fca199fa 100644 --- a/infini_train/include/nn/parallel/work.h +++ b/infini_train/include/nn/parallel/work.h @@ -44,7 +44,7 @@ class WorkNccl final : public Work { ~WorkNccl() override; bool WaitBlocking(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) override; - bool WaitNonBlocking(); + bool WaitNonBlocking() override; bool IsCompleted() const override; bool IsSuccess() const override; diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc index 41be205e..07826077 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -12,6 +12,33 @@ #include "infini_train/include/tensor.h" namespace infini_train::nn::parallel::function { + +void AllReduce(const std::shared_ptr &tensor, ReduceOpType reduce_op, const ProcessGroup *pg, bool async_op) { + auto device = tensor->GetDevice()->Type(); + if (pg == nullptr) { + pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); + } + pg->AllReduce(tensor, {reduce_op, async_op}); +} + +void AllGather(const std::shared_ptr &output, const std::shared_ptr &input, const ProcessGroup *pg, + bool async_op) { + auto device = output->GetDevice()->Type(); + if (pg == nullptr) { + pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); + } + pg->AllGather(output, input, async_op); +} + +void ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, ReduceOpType reduce_op, + const ProcessGroup *pg, bool async_op) { + auto device = output->GetDevice()->Type(); + if (pg == nullptr) { + pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); + } + pg->ReduceScatter(output, input, {reduce_op, async_op}); +} + std::vector>> Scatter(const std::vector> &input_tensors, const std::vector &devices, int dim) { std::vector>> output_tensors; @@ -34,34 +61,6 @@ std::vector> Gather(const std::vector(target_device, dim)->Apply(gather_tensors); } -void AllReduce(const std::shared_ptr &tensor, ReduceOpType reduce_op, const ProcessGroup *pg) { - // TODO(dcj): use no_grad mode later - auto device = tensor->GetDevice()->Type(); - if (pg == nullptr) { - pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); - } - pg->AllReduce(tensor, reduce_op); -} - -void AllGather(const std::shared_ptr &output, const std::shared_ptr &input, const ProcessGroup *pg) { - // TODO(zbl): use no_grad mode later - auto device = output->GetDevice()->Type(); - if (pg == nullptr) { - pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); - } - pg->AllGather(output, input); -} - -void ReduceScatter(const 
std::shared_ptr &output, const std::shared_ptr &input, ReduceOpType reduce_op, - const ProcessGroup *pg) { - // TODO(zbl): use no_grad mode later - auto device = output->GetDevice()->Type(); - if (pg == nullptr) { - pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); - } - pg->ReduceScatter(output, input, reduce_op); -} - std::vector>> BroadcastCoalescedReshape(const std::vector> &tensors, const std::vector &devices) { diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc index d5702249..f6081f56 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -2,14 +2,12 @@ #include "infini_train/include/nn/parallel/pp/pipeline_schedule.h" #include -#include #include #include #include "glog/logging.h" #include "infini_train/include/autocast.h" -#include "infini_train/include/autograd/grad_mode.h" #include "infini_train/include/datatype.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/init.h" diff --git a/infini_train/src/nn/parallel/pp/send_recv.cc b/infini_train/src/nn/parallel/pp/send_recv.cc index 6a24a0e7..bac71f0b 100644 --- a/infini_train/src/nn/parallel/pp/send_recv.cc +++ b/infini_train/src/nn/parallel/pp/send_recv.cc @@ -63,7 +63,7 @@ std::vector> ISend::Forward(const std::vectorGet(GetPipelineParallelProcessGroupName(input_device_->rank().GlobalRank())); - pp_group->NcclSend(input_tensors, peer_rank_); + pp_group->Send(input_tensors, peer_rank_, false); return input_tensors; } @@ -79,14 +79,16 @@ std::vector> ISend::Backward(const std::vectorGet(GetPipelineParallelProcessGroupName(input_device_->rank().GlobalRank())); - return pp_group->NcclRecv(recv_tensors, peer_rank_); + pp_group->Recv(recv_tensors, peer_rank_, false); + + return recv_tensors; } std::vector> IRecv::Forward(const std::vector> &recv_tensors) { CHECK_NOTNULL(src_device_); auto pp_group = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(src_device_->rank().GlobalRank())); - pp_group->NcclRecv(recv_tensors, peer_rank_); + pp_group->Recv(recv_tensors, peer_rank_, false); return recv_tensors; } @@ -102,7 +104,10 @@ void IRecv::SetupContext(const std::vector> &input_tenso std::vector> IRecv::Backward(const std::vector> &grad_outputs) { auto pp_group = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(cur_device_->rank().GlobalRank())); - return pp_group->NcclSend(grad_outputs, peer_rank_); + + pp_group->Send(grad_outputs, peer_rank_, false); + + return grad_outputs; } } // namespace functions diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index bb386a75..bfeddb95 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -185,35 +185,197 @@ void ProcessGroup::InitStreams() { int ProcessGroup::GetGroupRank(int global_rank) const { return global_group_rank_map_.at(global_rank); } -void ProcessGroup::AllReduce(const std::shared_ptr &tensor, function::ReduceOpType reduce_op) const { +std::shared_ptr ProcessGroup::AllReduce(const std::shared_ptr &tensor, + const function::AllreduceOptions &opts) const { void *buffer = tensor->DataPtr(); - const auto *device = dynamic_cast(tensor->GetDevice()); + device->SetDevice(); + auto comm = device_comm_map_.at(device); - device->SetDevice(); + cudaStream_t compute_stream = device->Stream(); + cudaStream_t comm_stream = device_stream_map_.at(device); + + 
auto work = std::make_shared(device, comm); + + cudaEvent_t ready_event = reinterpret_cast(work->ready_event()); + cudaEvent_t done_event = reinterpret_cast(work->done_event()); + + CUDA_CHECK(cudaEventRecord(ready_event, compute_stream)); + CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0)); + + // Perform NcclAllReduce on comm stream NCCL_CHECK(ncclAllReduce(buffer, buffer, tensor->NumElements(), kNcclDtypeMap.at(tensor->Dtype()), - kNcclReduceOpMap.at(reduce_op), comm, device->Stream())); + kNcclReduceOpMap.at(opts.reduce_op_type), comm, comm_stream)); + + CUDA_CHECK(cudaEventRecord(done_event, comm_stream)); + + if (opts.async_op) { + return std::move(work); + } else { + work->WaitNonBlocking(); + return nullptr; + } } -void ProcessGroup::AllGather(const std::shared_ptr &output, const std::shared_ptr &input) const { +std::shared_ptr ProcessGroup::AllGather(const std::shared_ptr &output, + const std::shared_ptr &input, bool async_op) const { const auto *device = dynamic_cast(input->GetDevice()); auto comm = device_comm_map_.at(device); device->SetDevice(); + + cudaStream_t compute_stream = device->Stream(); + cudaStream_t comm_stream = device_stream_map_.at(device); + + auto work = std::make_shared(device, comm); + + cudaEvent_t ready_event = reinterpret_cast(work->ready_event()); + cudaEvent_t done_event = reinterpret_cast(work->done_event()); + + CUDA_CHECK(cudaEventRecord(ready_event, compute_stream)); + CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0)); + NCCL_CHECK(ncclAllGather(input->DataPtr(), output->DataPtr(), input->NumElements(), - kNcclDtypeMap.at(input->Dtype()), comm, device->Stream())); + kNcclDtypeMap.at(input->Dtype()), comm, comm_stream)); + + CUDA_CHECK(cudaEventRecord(done_event, comm_stream)); + + if (async_op) { + return std::move(work); + } else { + work->WaitNonBlocking(); + return nullptr; + } } -void ProcessGroup::ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, - function::ReduceOpType reduce_op) const { +std::shared_ptr ProcessGroup::ReduceScatter(const std::shared_ptr &output, + const std::shared_ptr &input, + const function::AllreduceOptions &opts) const { const auto *device = dynamic_cast(input->GetDevice()); auto comm = device_comm_map_.at(device); device->SetDevice(); + + cudaStream_t compute_stream = device->Stream(); + cudaStream_t comm_stream = device_stream_map_.at(device); + + auto work = std::make_shared(device, comm); + + cudaEvent_t ready_event = reinterpret_cast(work->ready_event()); + cudaEvent_t done_event = reinterpret_cast(work->done_event()); + + CUDA_CHECK(cudaEventRecord(ready_event, compute_stream)); + CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0)); + NCCL_CHECK(ncclReduceScatter(input->DataPtr(), output->DataPtr(), output->NumElements(), - kNcclDtypeMap.at(input->Dtype()), kNcclReduceOpMap.at(reduce_op), comm, - device->Stream())); + kNcclDtypeMap.at(input->Dtype()), kNcclReduceOpMap.at(opts.reduce_op_type), comm, + comm_stream)); + + CUDA_CHECK(cudaEventRecord(done_event, comm_stream)); + + if (opts.async_op) { + return std::move(work); + } else { + work->WaitNonBlocking(); + return nullptr; + } +} + +std::shared_ptr ProcessGroup::Send(std::vector> tensors, int dest_rank, + bool async_op) const { + CHECK_GT(tensors.size(), 0); + const auto *device = dynamic_cast(tensors[0]->GetDevice()); + auto comm = device_comm_map_.at(device); + + device->SetDevice(); + + cudaStream_t compute_stream = device->Stream(); + cudaStream_t comm_stream = device_stream_map_.at(device); + + auto 
work = std::make_shared<WorkNccl>(device, comm);
+
+    cudaEvent_t ready_event = reinterpret_cast<cudaEvent_t>(work->ready_event());
+    cudaEvent_t done_event = reinterpret_cast<cudaEvent_t>(work->done_event());
+
+    CUDA_CHECK(cudaEventRecord(ready_event, compute_stream));
+    CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0));
+
+    for (int i = 0; i < tensors.size(); ++i) {
+        auto tensor = tensors[i];
+        CHECK_NOTNULL(tensor);
+
+        CHECK_EQ(device, tensor->GetDevice());
+
+        auto dtype = tensor->Dtype();
+        auto nccl_dtype = kNcclDtypeMap.at(dtype);
+        auto count = tensor->NumElements();
+        void *buffer = tensor->DataPtr();
+        CHECK_NOTNULL(buffer);
+
+        NCCL_CHECK(ncclSend(buffer, count, nccl_dtype, dest_rank, comm, comm_stream));
+    }
+
+    CUDA_CHECK(cudaEventRecord(done_event, comm_stream));
+
+    if (async_op) {
+        return std::move(work);
+    } else {
+        work->WaitNonBlocking();
+        return nullptr;
+    }
+
+    return std::move(work);
+}
+
+std::shared_ptr<Work> ProcessGroup::Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank,
+                                         bool async_op) const {
+    CHECK_GT(tensors.size(), 0);
+    const auto *device = dynamic_cast(tensors[0]->GetDevice());
+    auto comm = device_comm_map_.at(device);
+
+    device->SetDevice();
+
+    cudaStream_t compute_stream = device->Stream();
+    cudaStream_t comm_stream = device_stream_map_.at(device);
+
+    auto work = std::make_shared<WorkNccl>(device, comm);
+
+    cudaEvent_t ready_event = reinterpret_cast<cudaEvent_t>(work->ready_event());
+    cudaEvent_t done_event = reinterpret_cast<cudaEvent_t>(work->done_event());
+
+    CUDA_CHECK(cudaEventRecord(ready_event, compute_stream));
+    CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0));
+
+    CHECK_NE(src_rank, -1) << "Destination device not found in input tensors's devices";
+
+    for (int i = 0; i < tensors.size(); ++i) {
+        auto tensor = tensors[i];
+        CHECK_NOTNULL(tensor);
+
+        CHECK_EQ(device, tensor->GetDevice());
+
+        CHECK_NE(src_rank, -1) << "Source device not found in input devices";
+
+        auto dtype = tensor->Dtype();
+        auto nccl_dtype = kNcclDtypeMap.at(dtype);
+        auto count = tensor->NumElements();
+        void *buffer = tensor->DataPtr();
+        CHECK_NOTNULL(buffer);
+
+        NCCL_CHECK(ncclRecv(buffer, count, nccl_dtype, src_rank, comm, comm_stream));
+    }
+
+    CUDA_CHECK(cudaEventRecord(done_event, comm_stream));
+
+    if (async_op) {
+        return std::move(work);
+    } else {
+        work->WaitNonBlocking();
+        return nullptr;
+    }
+
+    return std::move(work);
 }
 
 std::vector<std::shared_ptr<Tensor>>
@@ -395,86 +557,6 @@ std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<
-std::vector<std::shared_ptr<Tensor>> ProcessGroup::NcclSend(std::vector<std::shared_ptr<Tensor>> tensors,
-                                                            int dest_rank) const {
-    for (int i = 0; i < tensors.size(); ++i) {
-        auto tensor = tensors[i];
-        CHECK_NOTNULL(tensor);
-
-        auto device = tensor->GetDevice();
-        device->SetDevice();
-
-        cudaStream_t stream = dynamic_cast(device)->Stream();
-        ncclComm_t comm = device_comm_map_.at(device);
-
-        CHECK_NE(dest_rank, -1) << "Destination device not found in input tensors's devices";
-
-        auto dtype = tensor->Dtype();
-        auto nccl_dtype = kNcclDtypeMap.at(dtype);
-        auto count = tensor->NumElements();
-        void *buffer = tensor->DataPtr();
-        CHECK_NOTNULL(buffer);
-
-        NCCL_CHECK(ncclSend(buffer, count, nccl_dtype, dest_rank, comm, stream));
-    }
-    return tensors;
-}
-
-std::vector<std::shared_ptr<Tensor>> ProcessGroup::NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors,
-                                                            int src_rank) const {
-    for (int i = 0; i < tensors.size(); ++i) {
-        auto tensor = tensors[i];
-        CHECK_NOTNULL(tensor);
-
-        auto device = tensor->GetDevice();
-        device->SetDevice();
-
-        cudaStream_t stream = dynamic_cast(device)->Stream();
-        ncclComm_t comm = device_comm_map_.at(device);
-
-        CHECK_NE(src_rank, -1) << "Source device not found
in input devices"; - - auto dtype = tensor->Dtype(); - auto nccl_dtype = kNcclDtypeMap.at(dtype); - auto count = tensor->NumElements(); - void *buffer = tensor->DataPtr(); - CHECK_NOTNULL(buffer); - - NCCL_CHECK(ncclRecv(buffer, count, nccl_dtype, src_rank, comm, stream)); - } - return tensors; -} - -std::shared_ptr ProcessGroup::AllReduceAsync(const std::shared_ptr &tensor, - function::ReduceOpType reduce_op) const { - void *buffer = tensor->DataPtr(); - const auto *device = dynamic_cast(tensor->GetDevice()); - device->SetDevice(); - - auto comm = device_comm_map_.at(device); - - cudaStream_t compute_stream = device->Stream(); - cudaStream_t comm_stream = device_stream_map_.at(device); - - auto work = std::make_shared(device, comm); - - cudaEvent_t ready_event = reinterpret_cast(work->ready_event()); - cudaEvent_t done_event = reinterpret_cast(work->done_event()); - - CUDA_CHECK(cudaEventRecord(ready_event, compute_stream)); - CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0)); - - // Perform NcclAllReduce on comm stream - NCCL_CHECK(ncclAllReduce(buffer, buffer, tensor->NumElements(), kNcclDtypeMap.at(tensor->Dtype()), - kNcclReduceOpMap.at(reduce_op), comm, comm_stream)); - - CUDA_CHECK(cudaEventRecord(done_event, comm_stream)); - - // Do not let compute stream wait for done event here - return std::move(work); -} - #endif ProcessGroupFactory *ProcessGroupFactory::Instance() { diff --git a/infini_train/src/nn/parallel/reducer.cc b/infini_train/src/nn/parallel/reducer.cc index 32542969..11290c7f 100644 --- a/infini_train/src/nn/parallel/reducer.cc +++ b/infini_train/src/nn/parallel/reducer.cc @@ -4,7 +4,6 @@ #include #include #include -#include #ifdef USE_CUDA #include @@ -13,7 +12,6 @@ #include "glog/logging.h" #include "infini_train/include/autograd/function_hook.h" -#include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/nn/parallel/work.h" #include "infini_train/include/tensor.h" @@ -421,7 +419,7 @@ void Reducer::FinalizeBucketDense(size_t bucket_index) { // FIXME(zbl): support custom hook later LOG(FATAL) << "Custom hook is not supported now"; } else { - bucket.work = ddp_pg->AllReduceAsync(bucket.contents, function::ReduceOpType::kAvg); + bucket.work = ddp_pg->AllReduce(bucket.contents, {function::ReduceOpType::kAvg, true}); } } diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index bc90a6d2..661cd6f6 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -15,6 +15,7 @@ #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/parallel_functional.h" #include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/nn/parallel/work.h" #include "infini_train/include/tensor.h" namespace infini_train::nn::parallel { @@ -41,7 +42,7 @@ std::shared_ptr GatherAlongFirstDim(const std::shared_ptr &tenso output_shape[0] *= world_size; auto gathered_output = std::make_shared(output_shape, tensor->Dtype(), device); - tp_group->AllGather(gathered_output, tensor); + tp_group->AllGather(gathered_output, tensor, false); return gathered_output; } @@ -61,7 +62,7 @@ std::shared_ptr GatherAlongLastDim(const std::shared_ptr &tensor output_shape[0] *= world_size; auto gathered_output = std::make_shared(output_shape, tensor->Dtype(), device); - tp_group->AllGather(gathered_output, tensor); + tp_group->AllGather(gathered_output, 
tensor, false); // AllGather gather along dim 0 by default auto output_list = gathered_output->Split(tensor->Dims()[0], 0); @@ -102,7 +103,7 @@ std::shared_ptr Reduce(const std::shared_ptr &tensor) { auto output = std::make_shared(*tensor); - tp_group->AllReduce(output, function::ReduceOpType::kSum); + tp_group->AllReduce(output, {function::ReduceOpType::kSum, false}); return output; } @@ -124,7 +125,8 @@ std::shared_ptr ReduceScatterAlongFirstDim(const std::shared_ptr auto output = std::make_shared(output_shape, tensor->Dtype(), device); - tp_group->ReduceScatter(output, tensor, function::ReduceOpType::kSum); + tp_group->ReduceScatter(output, tensor, {function::ReduceOpType::kSum, false}); + return output; } @@ -463,7 +465,7 @@ VocabParallelCrossEntropy::Forward(const std::vector> &i auto local_max = logits_masked->Max(-1); auto global_max = local_max; if (tp_size > 1) { - tp_group->AllReduce(global_max, function::ReduceOpType::kMax); + tp_group->AllReduce(global_max, {function::ReduceOpType::kMax, false}); } auto shifted = logits_masked->Sub(global_max->Unsqueeze(-1)); From d2eb083b380de576eb7db701cff344d925bdee69 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Mon, 22 Dec 2025 02:30:28 +0000 Subject: [PATCH 2/3] refactor: abstract a ProcessGroup base class, and implement the concrete ProcessGroupNCCL class --- .../include/nn/parallel/parallel_functional.h | 19 ++-- .../include/nn/parallel/process_group.h | 90 ++++++++++++++----- infini_train/include/nn/parallel/work.h | 1 - .../src/nn/parallel/parallel_functional.cc | 17 ++-- infini_train/src/nn/parallel/process_group.cc | 64 ++++++------- 5 files changed, 120 insertions(+), 71 deletions(-) diff --git a/infini_train/include/nn/parallel/parallel_functional.h b/infini_train/include/nn/parallel/parallel_functional.h index a6ad9952..f2559e2d 100644 --- a/infini_train/include/nn/parallel/parallel_functional.h +++ b/infini_train/include/nn/parallel/parallel_functional.h @@ -15,21 +15,22 @@ class Module; } // namespace infini_train namespace infini_train::nn::parallel::function { + +std::shared_ptr AllReduce(const std::shared_ptr &tensor, ReduceOpType reduce_op, + const ProcessGroup *pg = nullptr, bool async_op = false); + +std::shared_ptr AllGather(const std::shared_ptr &output, const std::shared_ptr &input, + const ProcessGroup *pg = nullptr, bool async_op = false); + +std::shared_ptr ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, + ReduceOpType reduce_op, const ProcessGroup *pg = nullptr, bool async_op = false); + std::vector>> Scatter(const std::vector> &input_tensors, const std::vector &device_ids, int dim); std::vector> Gather(const std::vector>> &outputs, const Device *target_device, int dim); -void AllReduce(const std::shared_ptr &tensor, ReduceOpType reduce_op, const ProcessGroup *pg = nullptr, - bool async_op = false); - -void AllGather(const std::shared_ptr &output, const std::shared_ptr &input, - const ProcessGroup *pg = nullptr, bool async_op = false); - -void ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, ReduceOpType reduce_op, - const ProcessGroup *pg = nullptr, bool async_op = false); - std::vector>> BroadcastCoalescedReshape(const std::vector> &tensors, const std::vector &devices); diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index d7f8c879..12d53c59 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -28,40 +28,96 @@ class Work; 
namespace infini_train::nn::parallel { -#ifdef USE_NCCL class ProcessGroup { public: - explicit ProcessGroup(const std::string &process_group_name, const std::vector &device_indices); + virtual int GetGroupRank(int global_rank) const; + + // Asynchronous communication APIs (Compute / Communication stream decoupled) + virtual std::shared_ptr AllReduce(const std::shared_ptr &tensor, + const function::AllreduceOptions &opts) const + = 0; + + virtual std::shared_ptr AllGather(const std::shared_ptr &output, const std::shared_ptr &input, + bool async_op) const + = 0; + + virtual std::shared_ptr ReduceScatter(const std::shared_ptr &output, + const std::shared_ptr &input, + const function::AllreduceOptions &opts) const + = 0; + + virtual std::shared_ptr Send(std::vector> tensors, int dest_rank, bool async_op) const + = 0; + + virtual std::shared_ptr Recv(std::vector> tensors, int src_rank, bool async_op) const + = 0; + + // Legacy communication APIs (Single-stream) + virtual std::vector> + BroadCast(const std::vector> &input_tensors) const = 0; - ~ProcessGroup(); + virtual std::vector> + ReduceAddCoalesced(const std::vector>> &grads, const Device *destination) const + = 0; - int GetGroupRank(int global_rank) const; + virtual std::vector> Scatter(const std::shared_ptr &tensor, + std::vector devices, int64_t dim) const + = 0; + + virtual std::shared_ptr Gather(const std::vector> &tensors, + const Device *destination, int64_t dim) const + = 0; + +protected: + ProcessGroup(int world_size, const std::string &name); + + std::vector devices_; + + std::unordered_map global_group_rank_map_; // global_rank : group_rank + + int world_size_ = 0; + + const std::string name_ = ""; + + bool is_main_process_ = false; +}; + +#ifdef USE_NCCL +class ProcessGroupNCCL final : public ProcessGroup { +public: + explicit ProcessGroupNCCL(const std::string &process_group_name, const std::vector &device_indices); + + ~ProcessGroupNCCL(); // Asynchronous communication APIs (Compute / Communication stream decoupled) std::shared_ptr AllReduce(const std::shared_ptr &tensor, - const function::AllreduceOptions &opts) const; + const function::AllreduceOptions &opts) const override; std::shared_ptr AllGather(const std::shared_ptr &output, const std::shared_ptr &input, - bool async_op) const; + bool async_op) const override; std::shared_ptr ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, - const function::AllreduceOptions &opts) const; + const function::AllreduceOptions &opts) const override; - std::shared_ptr Send(std::vector> tensors, int dest_rank, bool async_op) const; + std::shared_ptr Send(std::vector> tensors, int dest_rank, + bool async_op) const override; - std::shared_ptr Recv(std::vector> tensors, int src_rank, bool async_op) const; + std::shared_ptr Recv(std::vector> tensors, int src_rank, + bool async_op) const override; // Legacy communication APIs (Single-stream) - std::vector> BroadCast(const std::vector> &input_tensors) const; + std::vector> + BroadCast(const std::vector> &input_tensors) const override; std::vector> - ReduceAddCoalesced(const std::vector>> &grads, const Device *destination) const; + ReduceAddCoalesced(const std::vector>> &grads, + const Device *destination) const override; std::vector> Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const; + std::vector devices, int64_t dim) const override; std::shared_ptr Gather(const std::vector> &tensors, const Device *destination, - int64_t dim) const; + int64_t dim) const override; private: void 
InitSingleProcess(const std::vector &ranks); @@ -73,17 +129,9 @@ class ProcessGroup { private: std::vector comms_; std::vector comm_streams_; - std::vector devices_; std::unordered_map device_comm_map_; std::unordered_map device_stream_map_; - std::unordered_map global_group_rank_map_; // global_rank : group_rank - - int world_size_ = 0; - - const std::string name_ = ""; - - bool is_main_process_ = false; }; #endif diff --git a/infini_train/include/nn/parallel/work.h b/infini_train/include/nn/parallel/work.h index fca199fa..1e11cc02 100644 --- a/infini_train/include/nn/parallel/work.h +++ b/infini_train/include/nn/parallel/work.h @@ -3,7 +3,6 @@ #include #include #include -#include #include #ifdef USE_CUDA diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc index 07826077..282f380a 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -13,30 +13,31 @@ namespace infini_train::nn::parallel::function { -void AllReduce(const std::shared_ptr &tensor, ReduceOpType reduce_op, const ProcessGroup *pg, bool async_op) { +std::shared_ptr AllReduce(const std::shared_ptr &tensor, ReduceOpType reduce_op, const ProcessGroup *pg, + bool async_op) { auto device = tensor->GetDevice()->Type(); if (pg == nullptr) { pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); } - pg->AllReduce(tensor, {reduce_op, async_op}); + return pg->AllReduce(tensor, {reduce_op, async_op}); } -void AllGather(const std::shared_ptr &output, const std::shared_ptr &input, const ProcessGroup *pg, - bool async_op) { +std::shared_ptr AllGather(const std::shared_ptr &output, const std::shared_ptr &input, + const ProcessGroup *pg, bool async_op) { auto device = output->GetDevice()->Type(); if (pg == nullptr) { pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); } - pg->AllGather(output, input, async_op); + return pg->AllGather(output, input, async_op); } -void ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, ReduceOpType reduce_op, - const ProcessGroup *pg, bool async_op) { +std::shared_ptr ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, + ReduceOpType reduce_op, const ProcessGroup *pg, bool async_op) { auto device = output->GetDevice()->Type(); if (pg == nullptr) { pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); } - pg->ReduceScatter(output, input, {reduce_op, async_op}); + return pg->ReduceScatter(output, input, {reduce_op, async_op}); } std::vector>> Scatter(const std::vector> &input_tensors, diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index bfeddb95..a52eda18 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -86,9 +86,13 @@ void CleanupNcclIdFile(const std::string &pg_name) { namespace infini_train::nn::parallel { +int ProcessGroup::GetGroupRank(int global_rank) const { return global_group_rank_map_.at(global_rank); } + +ProcessGroup::ProcessGroup(int world_size, const std::string &name) : world_size_(world_size), name_(name) {} + #ifdef USE_NCCL -ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vector &ranks) - : world_size_(ranks.size()), name_(process_group_name) { +ProcessGroupNCCL::ProcessGroupNCCL(const std::string &process_group_name, const std::vector &ranks) + : ProcessGroup(ranks.size(), process_group_name) { int current_device = -1; 
 CUDA_CHECK(cudaGetDevice(&current_device));
@@ -102,7 +106,7 @@ ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vec
     CUDA_CHECK(cudaSetDevice(current_device));
 }
 
-ProcessGroup::~ProcessGroup() {
+ProcessGroupNCCL::~ProcessGroupNCCL() {
     if (is_main_process_) {
         CleanupNcclIdFile(name_);
     }
@@ -119,7 +123,7 @@ ProcessGroup::~ProcessGroup() {
     }
 }
 
-void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
+void ProcessGroupNCCL::InitSingleProcess(const std::vector<int> &ranks) {
     comms_.resize(world_size_);
     NCCL_CHECK(ncclCommInitAll(comms_.data(), world_size_, ranks.data()));
 
@@ -131,7 +135,7 @@ void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
     }
 }
 
-void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
+void ProcessGroupNCCL::InitMultiProcess(const std::vector<int> &ranks) {
     int n_threads = global::GetNthreadPerProc();
     int global_proc_rank = global::GetGlobalProcRank();
     int lower_rank = global_proc_rank * n_threads;
@@ -170,7 +174,7 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
     NCCL_CHECK(ncclGroupEnd());
 }
 
-void ProcessGroup::InitStreams() {
+void ProcessGroupNCCL::InitStreams() {
     int device_size = devices_.size();
 
     comm_streams_.resize(device_size);
@@ -183,10 +187,8 @@ void ProcessGroup::InitStreams() {
     }
 }
 
-int ProcessGroup::GetGroupRank(int global_rank) const { return global_group_rank_map_.at(global_rank); }
-
-std::shared_ptr<Work> ProcessGroup::AllReduce(const std::shared_ptr<Tensor> &tensor,
-                                              const function::AllreduceOptions &opts) const {
+std::shared_ptr<Work> ProcessGroupNCCL::AllReduce(const std::shared_ptr<Tensor> &tensor,
+                                                  const function::AllreduceOptions &opts) const {
     void *buffer = tensor->DataPtr();
     const auto *device = dynamic_cast(tensor->GetDevice());
     device->SetDevice();
@@ -218,8 +220,8 @@ std::shared_ptr<Work> ProcessGroup::AllReduce(const std::shared_ptr<Tensor> &ten
     }
 }
 
-std::shared_ptr<Work> ProcessGroup::AllGather(const std::shared_ptr<Tensor> &output,
-                                              const std::shared_ptr<Tensor> &input, bool async_op) const {
+std::shared_ptr<Work> ProcessGroupNCCL::AllGather(const std::shared_ptr<Tensor> &output,
+                                                  const std::shared_ptr<Tensor> &input, bool async_op) const {
     const auto *device = dynamic_cast(input->GetDevice());
     auto comm = device_comm_map_.at(device);
 
@@ -249,9 +251,9 @@ std::shared_ptr<Work> ProcessGroup::AllGather(const std::shared_ptr<Tensor> &out
     }
 }
 
-std::shared_ptr<Work> ProcessGroup::ReduceScatter(const std::shared_ptr<Tensor> &output,
-                                                  const std::shared_ptr<Tensor> &input,
-                                                  const function::AllreduceOptions &opts) const {
+std::shared_ptr<Work> ProcessGroupNCCL::ReduceScatter(const std::shared_ptr<Tensor> &output,
+                                                      const std::shared_ptr<Tensor> &input,
+                                                      const function::AllreduceOptions &opts) const {
     const auto *device = dynamic_cast(input->GetDevice());
     auto comm = device_comm_map_.at(device);
 
@@ -282,8 +284,8 @@ std::shared_ptr<Work> ProcessGroup::ReduceScatter(const std::shared_ptr<Tensor>
     }
 }
 
-std::shared_ptr<Work> ProcessGroup::Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
-                                         bool async_op) const {
+std::shared_ptr<Work> ProcessGroupNCCL::Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
+                                             bool async_op) const {
     CHECK_GT(tensors.size(), 0);
     const auto *device = dynamic_cast(tensors[0]->GetDevice());
     auto comm = device_comm_map_.at(device);
@@ -328,8 +330,8 @@ std::shared_ptr<Work> ProcessGroup::Send(std::vector<std::shared_ptr<Tensor>> te
     return std::move(work);
 }
 
-std::shared_ptr<Work> ProcessGroup::Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank,
-                                         bool async_op) const {
+std::shared_ptr<Work> ProcessGroupNCCL::Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank,
+                                             bool async_op) const {
     CHECK_GT(tensors.size(), 0);
     const auto *device = dynamic_cast(tensors[0]->GetDevice());
     auto comm = device_comm_map_.at(device);
@@
-347,16 +349,12 @@ std::shared_ptr ProcessGroup::Recv(std::vector> te CUDA_CHECK(cudaEventRecord(ready_event, compute_stream)); CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0)); - CHECK_NE(src_rank, -1) << "Destination device not found in input tensors's devices"; - for (int i = 0; i < tensors.size(); ++i) { auto tensor = tensors[i]; CHECK_NOTNULL(tensor); CHECK_EQ(device, tensor->GetDevice()); - CHECK_NE(src_rank, -1) << "Source device not found in input devices"; - auto dtype = tensor->Dtype(); auto nccl_dtype = kNcclDtypeMap.at(dtype); auto count = tensor->NumElements(); @@ -379,7 +377,7 @@ std::shared_ptr ProcessGroup::Recv(std::vector> te } std::vector> -ProcessGroup::BroadCast(const std::vector> &input_tensors) const { +ProcessGroupNCCL::BroadCast(const std::vector> &input_tensors) const { std::vector> outputs; std::vector streams; std::vector comms; @@ -425,8 +423,8 @@ ProcessGroup::BroadCast(const std::vector> &input_tensor } std::vector> -ProcessGroup::ReduceAddCoalesced(const std::vector>> &grads, - const Device *destination) const { +ProcessGroupNCCL::ReduceAddCoalesced(const std::vector>> &grads, + const Device *destination) const { // grads: [devices, tensors] std::vector> outputs; std::vector streams; @@ -470,8 +468,8 @@ ProcessGroup::ReduceAddCoalesced(const std::vector> ProcessGroup::Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const { +std::vector> ProcessGroupNCCL::Scatter(const std::shared_ptr &tensor, + std::vector devices, int64_t dim) const { std::vector> outputs; std::vector> split_tensors = tensor->Split(tensor->Dims()[dim] / devices.size(), dim); std::vector streams; @@ -505,8 +503,8 @@ std::vector> ProcessGroup::Scatter(const std::shared_ptr return outputs; } -std::shared_ptr ProcessGroup::Gather(const std::vector> &tensors, - const Device *destination, int64_t dim) const { +std::shared_ptr ProcessGroupNCCL::Gather(const std::vector> &tensors, + const Device *destination, int64_t dim) const { std::vector> outouts; int64_t num_devices = tensors.size(); auto dtype = tensors[0]->Dtype(); @@ -574,11 +572,13 @@ ProcessGroupFactory *ProcessGroupFactory::Instance() { const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, int comm_size) { std::vector device_indices(comm_size); std::iota(device_indices.begin(), device_indices.end(), 0); - return GetOrCreate(name, [&]() { return std::make_unique(name, device_indices); }); + // TODO(dcj): create device-specific ProcessGroup based on the registered device later + return GetOrCreate(name, [&]() { return std::make_unique(name, device_indices); }); } const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const std::vector &device_indices) { - return GetOrCreate(name, [&]() { return std::make_unique(name, device_indices); }); + // TODO(dcj): create device-specific ProcessGroup based on the registered device later + return GetOrCreate(name, [&]() { return std::make_unique(name, device_indices); }); } const ProcessGroup *ProcessGroupFactory::Get(const std::string &name) const { From 6dd760e52716f93081984937e22b9852ae52a53c Mon Sep 17 00:00:00 2001 From: kilinchange Date: Mon, 22 Dec 2025 07:15:34 +0000 Subject: [PATCH 3/3] fix: resolve review comments --- .../include/nn/parallel/process_group.h | 23 +++++++++++-------- .../include/nn/parallel/reduce_op_type.h | 5 ---- .../src/nn/parallel/parallel_functional.cc | 4 ++-- infini_train/src/nn/parallel/process_group.cc | 13 +++++------ infini_train/src/nn/parallel/reducer.cc | 2 +- 
.../src/nn/parallel/tensor_parallel.cc | 6 ++--- 6 files changed, 25 insertions(+), 28 deletions(-) diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 12d53c59..bf852a37 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -34,22 +34,25 @@ class ProcessGroup { // Asynchronous communication APIs (Compute / Communication stream decoupled) virtual std::shared_ptr AllReduce(const std::shared_ptr &tensor, - const function::AllreduceOptions &opts) const + function::ReduceOpType reduce_op = function::ReduceOpType::kSum, + bool async_op = false) const = 0; virtual std::shared_ptr AllGather(const std::shared_ptr &output, const std::shared_ptr &input, - bool async_op) const + bool async_op = false) const = 0; - virtual std::shared_ptr ReduceScatter(const std::shared_ptr &output, - const std::shared_ptr &input, - const function::AllreduceOptions &opts) const + virtual std::shared_ptr + ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, + function::ReduceOpType reduce_op = function::ReduceOpType::kSum, bool async_op = false) const = 0; - virtual std::shared_ptr Send(std::vector> tensors, int dest_rank, bool async_op) const + virtual std::shared_ptr Send(std::vector> tensors, int dest_rank, + bool async_op = false) const = 0; - virtual std::shared_ptr Recv(std::vector> tensors, int src_rank, bool async_op) const + virtual std::shared_ptr Recv(std::vector> tensors, int src_rank, + bool async_op = false) const = 0; // Legacy communication APIs (Single-stream) @@ -90,14 +93,14 @@ class ProcessGroupNCCL final : public ProcessGroup { ~ProcessGroupNCCL(); // Asynchronous communication APIs (Compute / Communication stream decoupled) - std::shared_ptr AllReduce(const std::shared_ptr &tensor, - const function::AllreduceOptions &opts) const override; + std::shared_ptr AllReduce(const std::shared_ptr &tensor, function::ReduceOpType reduce_op, + bool async_op) const override; std::shared_ptr AllGather(const std::shared_ptr &output, const std::shared_ptr &input, bool async_op) const override; std::shared_ptr ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, - const function::AllreduceOptions &opts) const override; + function::ReduceOpType reduce_op, bool async_op) const override; std::shared_ptr Send(std::vector> tensors, int dest_rank, bool async_op) const override; diff --git a/infini_train/include/nn/parallel/reduce_op_type.h b/infini_train/include/nn/parallel/reduce_op_type.h index f3ff6251..d366a1f0 100644 --- a/infini_train/include/nn/parallel/reduce_op_type.h +++ b/infini_train/include/nn/parallel/reduce_op_type.h @@ -11,9 +11,4 @@ enum class ReduceOpType : int8_t { kAvg, }; -struct AllreduceOptions { - ReduceOpType reduce_op_type = ReduceOpType::kSum; - bool async_op = false; -}; - } // namespace infini_train::nn::parallel::function diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc index 282f380a..50408949 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -19,7 +19,7 @@ std::shared_ptr AllReduce(const std::shared_ptr &tensor, ReduceOpT if (pg == nullptr) { pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); } - return pg->AllReduce(tensor, {reduce_op, async_op}); + return pg->AllReduce(tensor, reduce_op, async_op); } std::shared_ptr AllGather(const std::shared_ptr &output, 
const std::shared_ptr &input, @@ -37,7 +37,7 @@ std::shared_ptr ReduceScatter(const std::shared_ptr &output, const if (pg == nullptr) { pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); } - return pg->ReduceScatter(output, input, {reduce_op, async_op}); + return pg->ReduceScatter(output, input, reduce_op, async_op); } std::vector>> Scatter(const std::vector> &input_tensors, diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index a52eda18..50a75d48 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -188,7 +188,7 @@ void ProcessGroupNCCL::InitStreams() { } std::shared_ptr ProcessGroupNCCL::AllReduce(const std::shared_ptr &tensor, - const function::AllreduceOptions &opts) const { + function::ReduceOpType reduce_op, bool async_op) const { void *buffer = tensor->DataPtr(); const auto *device = dynamic_cast(tensor->GetDevice()); device->SetDevice(); @@ -208,11 +208,11 @@ std::shared_ptr ProcessGroupNCCL::AllReduce(const std::shared_ptr // Perform NcclAllReduce on comm stream NCCL_CHECK(ncclAllReduce(buffer, buffer, tensor->NumElements(), kNcclDtypeMap.at(tensor->Dtype()), - kNcclReduceOpMap.at(opts.reduce_op_type), comm, comm_stream)); + kNcclReduceOpMap.at(reduce_op), comm, comm_stream)); CUDA_CHECK(cudaEventRecord(done_event, comm_stream)); - if (opts.async_op) { + if (async_op) { return std::move(work); } else { work->WaitNonBlocking(); @@ -253,7 +253,7 @@ std::shared_ptr ProcessGroupNCCL::AllGather(const std::shared_ptr std::shared_ptr ProcessGroupNCCL::ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, - const function::AllreduceOptions &opts) const { + function::ReduceOpType reduce_op, bool async_op) const { const auto *device = dynamic_cast(input->GetDevice()); auto comm = device_comm_map_.at(device); @@ -271,12 +271,11 @@ std::shared_ptr ProcessGroupNCCL::ReduceScatter(const std::shared_ptrDataPtr(), output->DataPtr(), output->NumElements(), - kNcclDtypeMap.at(input->Dtype()), kNcclReduceOpMap.at(opts.reduce_op_type), comm, - comm_stream)); + kNcclDtypeMap.at(input->Dtype()), kNcclReduceOpMap.at(reduce_op), comm, comm_stream)); CUDA_CHECK(cudaEventRecord(done_event, comm_stream)); - if (opts.async_op) { + if (async_op) { return std::move(work); } else { work->WaitNonBlocking(); diff --git a/infini_train/src/nn/parallel/reducer.cc b/infini_train/src/nn/parallel/reducer.cc index 11290c7f..0cdd7703 100644 --- a/infini_train/src/nn/parallel/reducer.cc +++ b/infini_train/src/nn/parallel/reducer.cc @@ -419,7 +419,7 @@ void Reducer::FinalizeBucketDense(size_t bucket_index) { // FIXME(zbl): support custom hook later LOG(FATAL) << "Custom hook is not supported now"; } else { - bucket.work = ddp_pg->AllReduce(bucket.contents, {function::ReduceOpType::kAvg, true}); + bucket.work = ddp_pg->AllReduce(bucket.contents, function::ReduceOpType::kAvg, true); } } diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index 661cd6f6..b91028df 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -103,7 +103,7 @@ std::shared_ptr Reduce(const std::shared_ptr &tensor) { auto output = std::make_shared(*tensor); - tp_group->AllReduce(output, {function::ReduceOpType::kSum, false}); + tp_group->AllReduce(output, function::ReduceOpType::kSum, false); return output; } @@ -125,7 +125,7 @@ std::shared_ptr ReduceScatterAlongFirstDim(const 
std::shared_ptr auto output = std::make_shared(output_shape, tensor->Dtype(), device); - tp_group->ReduceScatter(output, tensor, {function::ReduceOpType::kSum, false}); + tp_group->ReduceScatter(output, tensor, function::ReduceOpType::kSum, false); return output; } @@ -465,7 +465,7 @@ VocabParallelCrossEntropy::Forward(const std::vector> &i auto local_max = logits_masked->Max(-1); auto global_max = local_max; if (tp_size > 1) { - tp_group->AllReduce(global_max, {function::ReduceOpType::kMax, false}); + tp_group->AllReduce(global_max, function::ReduceOpType::kMax, false); } auto shifted = logits_masked->Sub(global_max->Unsqueeze(-1));
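Usage sketch (illustrative, not part of the patch series above): the snippet below shows how a caller is expected to drive the reworked API after PATCH 3/3 — pass async_op = true to get a std::shared_ptr<Work> back and overlap communication with compute, or async_op = false to have the collective wait internally and return nullptr. The surrounding function, variable names, and the choice of the default process group are assumptions for illustration; only the factory lookup, AllReduce, ReduceOpType, and WaitNonBlocking come from this series.

#include <memory>

#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/nn/parallel/reduce_op_type.h"
#include "infini_train/include/nn/parallel/work.h"
#include "infini_train/include/tensor.h"

using namespace infini_train::nn::parallel;

// Hypothetical helper: average a gradient bucket across the default process group
// while letting the caller keep issuing compute work in between.
void AverageGradientsAsync(const std::shared_ptr<infini_train::Tensor> &grad_bucket) {
    const ProcessGroup *pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();

    // async_op = true: the all-reduce is enqueued on the group's communication stream
    // (after that stream waits on the ready event recorded on the compute stream) and
    // a Work handle is returned instead of blocking here.
    std::shared_ptr<Work> work = pg->AllReduce(grad_bucket, function::ReduceOpType::kAvg,
                                               /*async_op=*/true);

    // ... independent kernels can be launched on the compute stream here ...

    // Synchronize with the done event recorded on the communication stream before the
    // reduced gradients are consumed; the synchronous path (async_op = false) performs
    // this wait internally and returns nullptr.
    work->WaitNonBlocking();
}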