diff --git a/infini_train/include/nn/parallel/parallel_functional.h b/infini_train/include/nn/parallel/parallel_functional.h
index 25dccaf3..f2559e2d 100644
--- a/infini_train/include/nn/parallel/parallel_functional.h
+++ b/infini_train/include/nn/parallel/parallel_functional.h
@@ -15,20 +15,22 @@ class Module;
 } // namespace infini_train
 
 namespace infini_train::nn::parallel::function {
+
+std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor, ReduceOpType reduce_op,
+                                const ProcessGroup *pg = nullptr, bool async_op = false);
+
+std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
+                                const ProcessGroup *pg = nullptr, bool async_op = false);
+
+std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
+                                    ReduceOpType reduce_op, const ProcessGroup *pg = nullptr, bool async_op = false);
+
 std::vector<std::vector<std::shared_ptr<Tensor>>> Scatter(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                                                           const std::vector<const Device *> &device_ids, int dim);
 
 std::vector<std::shared_ptr<Tensor>> Gather(const std::vector<std::vector<std::shared_ptr<Tensor>>> &outputs,
                                             const Device *target_device, int dim);
 
-void AllReduce(const std::shared_ptr<Tensor> &tensor, ReduceOpType reduce_op, const ProcessGroup *pg = nullptr);
-
-void AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
-               const ProcessGroup *pg = nullptr);
-
-void ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input, ReduceOpType reduce_op,
-                   const ProcessGroup *pg = nullptr);
-
 std::vector<std::vector<std::shared_ptr<Tensor>>>
 BroadcastCoalescedReshape(const std::vector<std::shared_ptr<Tensor>> &tensors,
                           const std::vector<const Device *> &devices);
diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h
index d9896c85..bf852a37 100644
--- a/infini_train/include/nn/parallel/process_group.h
+++ b/infini_train/include/nn/parallel/process_group.h
@@ -28,40 +28,99 @@ class Work;
 
 namespace infini_train::nn::parallel {
 
-#ifdef USE_NCCL
 class ProcessGroup {
 public:
-    explicit ProcessGroup(const std::string &process_group_name, const std::vector<int> &device_indices);
+    virtual int GetGroupRank(int global_rank) const;
 
-    ~ProcessGroup();
+    // Asynchronous communication APIs (Compute / Communication stream decoupled)
+    virtual std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor,
+                                            function::ReduceOpType reduce_op = function::ReduceOpType::kSum,
+                                            bool async_op = false) const
+        = 0;
 
-    int GetGroupRank(int global_rank) const;
+    virtual std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
+                                            bool async_op = false) const
+        = 0;
 
-    // Communication operations
-    void AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;
+    virtual std::shared_ptr<Work>
+    ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
+                  function::ReduceOpType reduce_op = function::ReduceOpType::kSum, bool async_op = false) const
+        = 0;
 
-    void AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input) const;
+    virtual std::shared_ptr<Work> Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
+                                       bool async_op = false) const
+        = 0;
 
-    void ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
-                       function::ReduceOpType reduce_op) const;
+    virtual std::shared_ptr<Work> Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank,
+                                       bool async_op = false) const
+        = 0;
 
-    std::vector<std::shared_ptr<Tensor>> BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const;
+    // Legacy communication APIs (Single-stream)
+    virtual std::vector<std::shared_ptr<Tensor>>
+    BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const = 0;
 
-    std::vector<std::shared_ptr<Tensor>>
-    ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads, const Device *destination) const;
+    virtual std::vector<std::shared_ptr<Tensor>>
+    ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads, const Device *destination) const
+        = 0;
 
-    std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor,
-                                                 std::vector<const Device *> devices, int64_t dim) const;
+    virtual std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor,
+                                                         std::vector<const Device *> devices, int64_t dim) const
+        = 0;
 
-    std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, const Device *destination,
-                                   int64_t dim) const;
+    virtual std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors,
+                                           const Device *destination, int64_t dim) const
+        = 0;
+
+protected:
+    ProcessGroup(int world_size, const std::string &name);
+
+    std::vector<const Device *> devices_;
+
+    std::unordered_map<int, int> global_group_rank_map_; // global_rank : group_rank
+
+    int world_size_ = 0;
+
+    const std::string name_ = "";
+
+    bool is_main_process_ = false;
+};
+
+#ifdef USE_NCCL
+class ProcessGroupNCCL final : public ProcessGroup {
+public:
+    explicit ProcessGroupNCCL(const std::string &process_group_name, const std::vector<int> &device_indices);
 
-    std::vector<std::shared_ptr<Tensor>> NcclSend(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank) const;
+    ~ProcessGroupNCCL();
 
-    std::vector<std::shared_ptr<Tensor>> NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank) const;
+    // Asynchronous communication APIs (Compute / Communication stream decoupled)
+    std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op,
+                                    bool async_op) const override;
 
-    // Async communication functions
-    std::shared_ptr<Work> AllReduceAsync(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;
+    std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
+                                    bool async_op) const override;
+
+    std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
+                                        function::ReduceOpType reduce_op, bool async_op) const override;
+
+    std::shared_ptr<Work> Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
+                               bool async_op) const override;
+
+    std::shared_ptr<Work> Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank,
+                               bool async_op) const override;
+
+    // Legacy communication APIs (Single-stream)
+    std::vector<std::shared_ptr<Tensor>>
+    BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const override;
+
+    std::vector<std::shared_ptr<Tensor>>
+    ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads,
+                       const Device *destination) const override;
+
+    std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor,
+                                                 std::vector<const Device *> devices, int64_t dim) const override;
+
+    std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, const Device *destination,
+                                   int64_t dim) const override;
 
 private:
     void InitSingleProcess(const std::vector<int> &ranks);
@@ -73,17 +132,9 @@ class ProcessGroup {
 private:
     std::vector<ncclComm_t> comms_;
     std::vector<cudaStream_t> comm_streams_;
-    std::vector<const Device *> devices_;
 
     std::unordered_map<const Device *, ncclComm_t> device_comm_map_;
     std::unordered_map<const Device *, cudaStream_t> device_stream_map_;
-
-    std::unordered_map<int, int> global_group_rank_map_; // global_rank : group_rank
-
-    int world_size_ = 0;
-
-    const std::string name_ = "";
-
-    bool is_main_process_ = false;
 };
 #endif
diff --git a/infini_train/include/nn/parallel/reduce_op_type.h b/infini_train/include/nn/parallel/reduce_op_type.h
index 0178307e..d366a1f0 100644
--- a/infini_train/include/nn/parallel/reduce_op_type.h
+++ b/infini_train/include/nn/parallel/reduce_op_type.h
@@ -10,4 +10,5 @@ enum class ReduceOpType : int8_t {
     kMax,
     kAvg,
 };
+
 } // namespace infini_train::nn::parallel::function
diff --git a/infini_train/include/nn/parallel/work.h b/infini_train/include/nn/parallel/work.h
index 3f304b7a..1e11cc02 100644
--- a/infini_train/include/nn/parallel/work.h
+++ b/infini_train/include/nn/parallel/work.h
@@ -3,7 +3,6 @@
 #include
 #include
 #include
-#include
 #include
 
 #ifdef USE_CUDA
@@ -44,7 +43,7 @@ class WorkNccl final : public Work {
     ~WorkNccl() override;
 
     bool WaitBlocking(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) override;
-    bool WaitNonBlocking();
+    bool WaitNonBlocking() override;
 
     bool IsCompleted() const override;
     bool IsSuccess() const override;
diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc
index 41be205e..50408949 100644
--- a/infini_train/src/nn/parallel/parallel_functional.cc
+++ b/infini_train/src/nn/parallel/parallel_functional.cc
@@ -12,6 +12,34 @@
 #include "infini_train/include/tensor.h"
 
 namespace infini_train::nn::parallel::function {
+
+std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor, ReduceOpType reduce_op, const ProcessGroup *pg,
+                                bool async_op) {
+    auto device = tensor->GetDevice()->Type();
+    if (pg == nullptr) {
+        pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
+    }
+    return pg->AllReduce(tensor, reduce_op, async_op);
+}
+
+std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
+                                const ProcessGroup *pg, bool async_op) {
+    auto device = output->GetDevice()->Type();
+    if (pg == nullptr) {
+        pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
+    }
+    return pg->AllGather(output, input, async_op);
+}
+
+std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
+                                    ReduceOpType reduce_op, const ProcessGroup *pg, bool async_op) {
+    auto device = output->GetDevice()->Type();
+    if (pg == nullptr) {
+        pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
+    }
+    return pg->ReduceScatter(output, input, reduce_op, async_op);
+}
+
 std::vector<std::vector<std::shared_ptr<Tensor>>> Scatter(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                                                           const std::vector<const Device *> &devices, int dim) {
     std::vector<std::vector<std::shared_ptr<Tensor>>> output_tensors;
@@ -34,34 +62,6 @@ std::vector<std::shared_ptr<Tensor>> Gather(const std::vector
         (target_device, dim)->Apply(gather_tensors);
 }
 
-void AllReduce(const std::shared_ptr<Tensor> &tensor, ReduceOpType reduce_op, const ProcessGroup *pg) {
-    // TODO(dcj): use no_grad mode later
-    auto device = tensor->GetDevice()->Type();
-    if (pg == nullptr) {
-        pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
-    }
-    pg->AllReduce(tensor, reduce_op);
-}
-
-void AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input, const ProcessGroup *pg) {
-    // TODO(zbl): use no_grad mode later
-    auto device = output->GetDevice()->Type();
-    if (pg == nullptr) {
-        pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
-    }
-    pg->AllGather(output, input);
-}
-
-void ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input, ReduceOpType reduce_op,
-                   const ProcessGroup *pg) {
-    // TODO(zbl): use no_grad mode later
-    auto device = output->GetDevice()->Type();
-    if (pg == nullptr) {
-        pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
-    }
-    pg->ReduceScatter(output, input, reduce_op);
-}
-
 std::vector<std::vector<std::shared_ptr<Tensor>>>
 BroadcastCoalescedReshape(const std::vector<std::shared_ptr<Tensor>> &tensors,
                           const std::vector<const Device *> &devices) {
diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc
index d5702249..f6081f56 100644
--- a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc
+++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc
@@ -2,14 +2,12 @@
 #include "infini_train/include/nn/parallel/pp/pipeline_schedule.h"
 
 #include
-#include
 #include
 #include
 
 #include "glog/logging.h"
 
 #include "infini_train/include/autocast.h"
-#include "infini_train/include/autograd/grad_mode.h"
 #include "infini_train/include/datatype.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/init.h"
diff --git a/infini_train/src/nn/parallel/pp/send_recv.cc b/infini_train/src/nn/parallel/pp/send_recv.cc
index 6a24a0e7..bac71f0b 100644
--- a/infini_train/src/nn/parallel/pp/send_recv.cc
+++ b/infini_train/src/nn/parallel/pp/send_recv.cc
@@ -63,7 +63,7 @@ std::vector<std::shared_ptr<Tensor>> ISend::Forward(const std::vector
     auto pp_group = ProcessGroupFactory::Instance()->Get(
         GetPipelineParallelProcessGroupName(input_device_->rank().GlobalRank()));
-    pp_group->NcclSend(input_tensors, peer_rank_);
+    pp_group->Send(input_tensors, peer_rank_, false);
 
     return input_tensors;
 }
@@ -79,14 +79,16 @@ std::vector<std::shared_ptr<Tensor>> ISend::Backward(const std::vector
     auto pp_group = ProcessGroupFactory::Instance()->Get(
         GetPipelineParallelProcessGroupName(input_device_->rank().GlobalRank()));
-    return pp_group->NcclRecv(recv_tensors, peer_rank_);
+    pp_group->Recv(recv_tensors, peer_rank_, false);
+
+    return recv_tensors;
 }
 
 std::vector<std::shared_ptr<Tensor>> IRecv::Forward(const std::vector<std::shared_ptr<Tensor>> &recv_tensors) {
     CHECK_NOTNULL(src_device_);
     auto pp_group = ProcessGroupFactory::Instance()->Get(
         GetPipelineParallelProcessGroupName(src_device_->rank().GlobalRank()));
-    pp_group->NcclRecv(recv_tensors, peer_rank_);
+    pp_group->Recv(recv_tensors, peer_rank_, false);
 
     return recv_tensors;
 }
@@ -102,7 +104,10 @@ void IRecv::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tenso
 std::vector<std::shared_ptr<Tensor>> IRecv::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
     auto pp_group = ProcessGroupFactory::Instance()->Get(
         GetPipelineParallelProcessGroupName(cur_device_->rank().GlobalRank()));
-    return pp_group->NcclSend(grad_outputs, peer_rank_);
+
+    pp_group->Send(grad_outputs, peer_rank_, false);
+
+    return grad_outputs;
 }
 
 } // namespace functions
diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc
index bb386a75..50a75d48 100644
--- a/infini_train/src/nn/parallel/process_group.cc
+++ b/infini_train/src/nn/parallel/process_group.cc
@@ -86,9 +86,13 @@ void CleanupNcclIdFile(const std::string &pg_name) {
 
 namespace infini_train::nn::parallel {
 
+int ProcessGroup::GetGroupRank(int global_rank) const { return global_group_rank_map_.at(global_rank); }
+
+ProcessGroup::ProcessGroup(int world_size, const std::string &name) : world_size_(world_size), name_(name) {}
+
 #ifdef USE_NCCL
-ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vector<int> &ranks)
-    : world_size_(ranks.size()), name_(process_group_name) {
+ProcessGroupNCCL::ProcessGroupNCCL(const std::string &process_group_name, const std::vector<int> &ranks)
+    : ProcessGroup(ranks.size(), process_group_name) {
     int current_device = -1;
     CUDA_CHECK(cudaGetDevice(&current_device));
 
@@ -102,7 +106,7 @@ ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vec
     CUDA_CHECK(cudaSetDevice(current_device));
 }
 
-ProcessGroup::~ProcessGroup() {
+ProcessGroupNCCL::~ProcessGroupNCCL() {
     if (is_main_process_) {
         CleanupNcclIdFile(name_);
     }
@@ -119,7 +123,7 @@ ProcessGroup::~ProcessGroup() {
     }
 }
 
-void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
+void ProcessGroupNCCL::InitSingleProcess(const std::vector<int> &ranks) {
     comms_.resize(world_size_);
     NCCL_CHECK(ncclCommInitAll(comms_.data(), world_size_, ranks.data()));
 
@@ -131,7 +135,7 @@ void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
     }
 }
 
-void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
+void ProcessGroupNCCL::InitMultiProcess(const std::vector<int> &ranks) {
     int n_threads = global::GetNthreadPerProc();
     int global_proc_rank = global::GetGlobalProcRank();
    int lower_rank = global_proc_rank * n_threads;
@@ -170,7 +174,7 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
     NCCL_CHECK(ncclGroupEnd());
 }
 
-void ProcessGroup::InitStreams() {
+void ProcessGroupNCCL::InitStreams() {
     int device_size = devices_.size();
     comm_streams_.resize(device_size);
 
@@ -183,41 +187,196 @@ void ProcessGroup::InitStreams() {
     }
 }
 
-int ProcessGroup::GetGroupRank(int global_rank) const { return global_group_rank_map_.at(global_rank); }
-
-void ProcessGroup::AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const {
+std::shared_ptr<Work> ProcessGroupNCCL::AllReduce(const std::shared_ptr<Tensor> &tensor,
+                                                  function::ReduceOpType reduce_op, bool async_op) const {
     void *buffer = tensor->DataPtr();
-
     const auto *device = dynamic_cast<const CudaDevice *>(tensor->GetDevice());
+    device->SetDevice();
+
     auto comm = device_comm_map_.at(device);
 
-    device->SetDevice();
+    cudaStream_t compute_stream = device->Stream();
+    cudaStream_t comm_stream = device_stream_map_.at(device);
+
+    auto work = std::make_shared<WorkNccl>(device, comm);
+
+    cudaEvent_t ready_event = reinterpret_cast<cudaEvent_t>(work->ready_event());
+    cudaEvent_t done_event = reinterpret_cast<cudaEvent_t>(work->done_event());
+
+    CUDA_CHECK(cudaEventRecord(ready_event, compute_stream));
+    CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0));
+
+    // Perform NcclAllReduce on comm stream
     NCCL_CHECK(ncclAllReduce(buffer, buffer, tensor->NumElements(), kNcclDtypeMap.at(tensor->Dtype()),
-                             kNcclReduceOpMap.at(reduce_op), comm, device->Stream()));
+                             kNcclReduceOpMap.at(reduce_op), comm, comm_stream));
+
+    CUDA_CHECK(cudaEventRecord(done_event, comm_stream));
+
+    if (async_op) {
+        return std::move(work);
+    } else {
+        work->WaitNonBlocking();
+        return nullptr;
+    }
 }
 
-void ProcessGroup::AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input) const {
+std::shared_ptr<Work> ProcessGroupNCCL::AllGather(const std::shared_ptr<Tensor> &output,
+                                                  const std::shared_ptr<Tensor> &input, bool async_op) const {
     const auto *device = dynamic_cast<const CudaDevice *>(input->GetDevice());
     auto comm = device_comm_map_.at(device);
 
     device->SetDevice();
+
+    cudaStream_t compute_stream = device->Stream();
+    cudaStream_t comm_stream = device_stream_map_.at(device);
+
+    auto work = std::make_shared<WorkNccl>(device, comm);
+
+    cudaEvent_t ready_event = reinterpret_cast<cudaEvent_t>(work->ready_event());
+    cudaEvent_t done_event = reinterpret_cast<cudaEvent_t>(work->done_event());
+
+    CUDA_CHECK(cudaEventRecord(ready_event, compute_stream));
+    CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0));
+
     NCCL_CHECK(ncclAllGather(input->DataPtr(), output->DataPtr(), input->NumElements(),
-                             kNcclDtypeMap.at(input->Dtype()), comm, device->Stream()));
+                             kNcclDtypeMap.at(input->Dtype()), comm, comm_stream));
+
+    CUDA_CHECK(cudaEventRecord(done_event, comm_stream));
+
+    if (async_op) {
+        return std::move(work);
+    } else {
+        work->WaitNonBlocking();
+        return nullptr;
+    }
 }
 
-void ProcessGroup::ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
-                                 function::ReduceOpType reduce_op) const {
+std::shared_ptr<Work> ProcessGroupNCCL::ReduceScatter(const std::shared_ptr<Tensor> &output,
+                                                      const std::shared_ptr<Tensor> &input,
+                                                      function::ReduceOpType reduce_op, bool async_op) const {
     const auto *device = dynamic_cast<const CudaDevice *>(input->GetDevice());
     auto comm = device_comm_map_.at(device);
 
     device->SetDevice();
+
+    cudaStream_t compute_stream = device->Stream();
+    cudaStream_t comm_stream = device_stream_map_.at(device);
+
+    auto work = std::make_shared<WorkNccl>(device, comm);
+
+    cudaEvent_t ready_event = reinterpret_cast<cudaEvent_t>(work->ready_event());
+    cudaEvent_t done_event = reinterpret_cast<cudaEvent_t>(work->done_event());
+
+    CUDA_CHECK(cudaEventRecord(ready_event, compute_stream));
+    CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0));
+
     NCCL_CHECK(ncclReduceScatter(input->DataPtr(), output->DataPtr(), output->NumElements(),
-                                 kNcclDtypeMap.at(input->Dtype()), kNcclReduceOpMap.at(reduce_op), comm,
-                                 device->Stream()));
+                                 kNcclDtypeMap.at(input->Dtype()), kNcclReduceOpMap.at(reduce_op), comm, comm_stream));
+
+    CUDA_CHECK(cudaEventRecord(done_event, comm_stream));
+
+    if (async_op) {
+        return std::move(work);
+    } else {
+        work->WaitNonBlocking();
+        return nullptr;
+    }
+}
+
+std::shared_ptr<Work> ProcessGroupNCCL::Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
+                                             bool async_op) const {
+    CHECK_GT(tensors.size(), 0);
+    const auto *device = dynamic_cast<const CudaDevice *>(tensors[0]->GetDevice());
+    auto comm = device_comm_map_.at(device);
+
+    device->SetDevice();
+
+    cudaStream_t compute_stream = device->Stream();
+    cudaStream_t comm_stream = device_stream_map_.at(device);
+
+    auto work = std::make_shared<WorkNccl>(device, comm);
+
+    cudaEvent_t ready_event = reinterpret_cast<cudaEvent_t>(work->ready_event());
+    cudaEvent_t done_event = reinterpret_cast<cudaEvent_t>(work->done_event());
+
+    CUDA_CHECK(cudaEventRecord(ready_event, compute_stream));
+    CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0));
+
+    for (int i = 0; i < tensors.size(); ++i) {
+        auto tensor = tensors[i];
+        CHECK_NOTNULL(tensor);
+
+        CHECK_EQ(device, tensor->GetDevice());
+
+        auto dtype = tensor->Dtype();
+        auto nccl_dtype = kNcclDtypeMap.at(dtype);
+        auto count = tensor->NumElements();
+        void *buffer = tensor->DataPtr();
+        CHECK_NOTNULL(buffer);
+
+        NCCL_CHECK(ncclSend(buffer, count, nccl_dtype, dest_rank, comm, comm_stream));
+    }
+
+    CUDA_CHECK(cudaEventRecord(done_event, comm_stream));
+
+    if (async_op) {
+        return std::move(work);
+    } else {
+        work->WaitNonBlocking();
+        return nullptr;
+    }
+}
+
+std::shared_ptr<Work> ProcessGroupNCCL::Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank,
+                                             bool async_op) const {
+    CHECK_GT(tensors.size(), 0);
+    const auto *device = dynamic_cast<const CudaDevice *>(tensors[0]->GetDevice());
+    auto comm = device_comm_map_.at(device);
+
+    device->SetDevice();
+
+    cudaStream_t compute_stream = device->Stream();
+    cudaStream_t comm_stream = device_stream_map_.at(device);
+
+    auto work = std::make_shared<WorkNccl>(device, comm);
+
+    cudaEvent_t ready_event = reinterpret_cast<cudaEvent_t>(work->ready_event());
+    cudaEvent_t done_event = reinterpret_cast<cudaEvent_t>(work->done_event());
+
+    CUDA_CHECK(cudaEventRecord(ready_event, compute_stream));
+    CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0));
+
+    for (int i = 0; i < tensors.size(); ++i) {
+        auto tensor = tensors[i];
+        CHECK_NOTNULL(tensor);
+
+        CHECK_EQ(device, tensor->GetDevice());
+
+        auto dtype = tensor->Dtype();
+        auto nccl_dtype = kNcclDtypeMap.at(dtype);
+        auto count = tensor->NumElements();
+        void *buffer = tensor->DataPtr();
+        CHECK_NOTNULL(buffer);
+
+        NCCL_CHECK(ncclRecv(buffer, count, nccl_dtype, src_rank, comm, comm_stream));
+    }
+
+    CUDA_CHECK(cudaEventRecord(done_event, comm_stream));
+
+    if (async_op) {
+        return std::move(work);
+    } else {
+        work->WaitNonBlocking();
+        return nullptr;
+    }
 }
 
 std::vector<std::shared_ptr<Tensor>>
-ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const {
+ProcessGroupNCCL::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const {
     std::vector<std::shared_ptr<Tensor>> outputs;
     std::vector<cudaStream_t> streams;
     std::vector<ncclComm_t> comms;
@@ -263,8 +422,8 @@ ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensor
 }
 
 std::vector<std::shared_ptr<Tensor>>
-ProcessGroup::ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads,
-                                 const Device *destination) const {
+ProcessGroupNCCL::ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads,
+                                     const Device *destination) const {
     // grads: [devices, tensors]
     std::vector<std::shared_ptr<Tensor>> outputs;
     std::vector<cudaStream_t> streams;
@@ -308,8 +467,8 @@ ProcessGroup::ReduceAddCoalesced(const std::vector
 }
 
-std::vector<std::shared_ptr<Tensor>> ProcessGroup::Scatter(const std::shared_ptr<Tensor> &tensor,
-                                                           std::vector<const Device *> devices, int64_t dim) const {
+std::vector<std::shared_ptr<Tensor>> ProcessGroupNCCL::Scatter(const std::shared_ptr<Tensor> &tensor,
+                                                               std::vector<const Device *> devices, int64_t dim) const {
     std::vector<std::shared_ptr<Tensor>> outputs;
     std::vector<std::shared_ptr<Tensor>> split_tensors = tensor->Split(tensor->Dims()[dim] / devices.size(), dim);
     std::vector<cudaStream_t> streams;
@@ -343,8 +502,8 @@ std::vector<std::shared_ptr<Tensor>> ProcessGroup::Scatter(const std::shared_ptr
     return outputs;
 }
 
-std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<Tensor>> &tensors,
-                                             const Device *destination, int64_t dim) const {
+std::shared_ptr<Tensor> ProcessGroupNCCL::Gather(const std::vector<std::shared_ptr<Tensor>> &tensors,
+                                                 const Device *destination, int64_t dim) const {
     std::vector<std::shared_ptr<Tensor>> outouts;
     int64_t num_devices = tensors.size();
     auto dtype = tensors[0]->Dtype();
@@ -395,86 +554,6 @@ std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector
 }
 
-std::vector<std::shared_ptr<Tensor>> ProcessGroup::NcclSend(std::vector<std::shared_ptr<Tensor>> tensors,
-                                                            int dest_rank) const {
-    for (int i = 0; i < tensors.size(); ++i) {
-        auto tensor = tensors[i];
-        CHECK_NOTNULL(tensor);
-
-        auto device = tensor->GetDevice();
-        device->SetDevice();
-
-        cudaStream_t stream = dynamic_cast<const CudaDevice *>(device)->Stream();
-        ncclComm_t comm = device_comm_map_.at(device);
-
-        CHECK_NE(dest_rank, -1) << "Destination device not found in input tensors's devices";
-
-        auto dtype = tensor->Dtype();
-        auto nccl_dtype = kNcclDtypeMap.at(dtype);
-        auto count = tensor->NumElements();
-        void *buffer = tensor->DataPtr();
-        CHECK_NOTNULL(buffer);
-
-        NCCL_CHECK(ncclSend(buffer, count, nccl_dtype, dest_rank, comm, stream));
-    }
-    return tensors;
-}
-
-std::vector<std::shared_ptr<Tensor>> ProcessGroup::NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors,
-                                                            int src_rank) const {
-    for (int i = 0; i < tensors.size(); ++i) {
-        auto tensor = tensors[i];
-        CHECK_NOTNULL(tensor);
-
-        auto device = tensor->GetDevice();
-        device->SetDevice();
-
-        cudaStream_t stream = dynamic_cast<const CudaDevice *>(device)->Stream();
-        ncclComm_t comm = device_comm_map_.at(device);
-
-        CHECK_NE(src_rank, -1) << "Source device not found in input devices";
-
-        auto dtype = tensor->Dtype();
-        auto nccl_dtype = kNcclDtypeMap.at(dtype);
-        auto count = tensor->NumElements();
-        void *buffer = tensor->DataPtr();
-        CHECK_NOTNULL(buffer);
-
-        NCCL_CHECK(ncclRecv(buffer, count, nccl_dtype, src_rank, comm, stream));
-    }
-    return tensors;
-}
-
-std::shared_ptr<Work> ProcessGroup::AllReduceAsync(const std::shared_ptr<Tensor> &tensor,
-                                                   function::ReduceOpType reduce_op) const {
-    void *buffer = tensor->DataPtr();
-    const auto *device = dynamic_cast<const CudaDevice *>(tensor->GetDevice());
-    device->SetDevice();
-
-    auto comm = device_comm_map_.at(device);
-
-    cudaStream_t compute_stream = device->Stream();
-    cudaStream_t comm_stream = device_stream_map_.at(device);
-
-    auto work = std::make_shared<WorkNccl>(device, comm);
-
-    cudaEvent_t ready_event = reinterpret_cast<cudaEvent_t>(work->ready_event());
-    cudaEvent_t done_event = reinterpret_cast<cudaEvent_t>(work->done_event());
-
-    CUDA_CHECK(cudaEventRecord(ready_event, compute_stream));
-    CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0));
-
-    // Perform NcclAllReduce on comm stream
-    NCCL_CHECK(ncclAllReduce(buffer, buffer, tensor->NumElements(), kNcclDtypeMap.at(tensor->Dtype()),
-                             kNcclReduceOpMap.at(reduce_op), comm, comm_stream));
-
-    CUDA_CHECK(cudaEventRecord(done_event, comm_stream));
-
-    // Do not let compute stream wait for done event here
-    return std::move(work);
-}
-
 #endif
 
 ProcessGroupFactory *ProcessGroupFactory::Instance() {
@@ -492,11 +571,13 @@ ProcessGroupFactory *ProcessGroupFactory::Instance() {
 const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, int comm_size) {
     std::vector<int> device_indices(comm_size);
     std::iota(device_indices.begin(), device_indices.end(), 0);
-    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(name, device_indices); });
+    // TODO(dcj): create device-specific ProcessGroup based on the registered device later
+    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroupNCCL>(name, device_indices); });
 }
 
 const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const std::vector<int> &device_indices) {
-    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroup>(name, device_indices); });
+    // TODO(dcj): create device-specific ProcessGroup based on the registered device later
+    return GetOrCreate(name, [&]() { return std::make_unique<ProcessGroupNCCL>(name, device_indices); });
 }
 
 const ProcessGroup *ProcessGroupFactory::Get(const std::string &name) const {
diff --git a/infini_train/src/nn/parallel/reducer.cc b/infini_train/src/nn/parallel/reducer.cc
index 32542969..0cdd7703 100644
--- a/infini_train/src/nn/parallel/reducer.cc
+++ b/infini_train/src/nn/parallel/reducer.cc
@@ -4,7 +4,6 @@
 #include
 #include
 #include
-#include
 
 #ifdef USE_CUDA
 #include
@@ -13,7 +12,6 @@
 #include "glog/logging.h"
 
 #include "infini_train/include/autograd/function_hook.h"
-#include "infini_train/include/common/cuda/common_cuda.h"
 #include "infini_train/include/nn/parallel/utils.h"
 #include "infini_train/include/nn/parallel/work.h"
 #include "infini_train/include/tensor.h"
@@ -421,7 +419,7 @@ void Reducer::FinalizeBucketDense(size_t bucket_index) {
             // FIXME(zbl): support custom hook later
             LOG(FATAL) << "Custom hook is not supported now";
         } else {
-            bucket.work = ddp_pg->AllReduceAsync(bucket.contents, function::ReduceOpType::kAvg);
+            bucket.work = ddp_pg->AllReduce(bucket.contents, function::ReduceOpType::kAvg, true);
         }
     }
diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc
index bc90a6d2..b91028df 100644
--- a/infini_train/src/nn/parallel/tensor_parallel.cc
+++ b/infini_train/src/nn/parallel/tensor_parallel.cc
@@ -15,6 +15,7 @@
 #include "infini_train/include/nn/parallel/global.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
 #include "infini_train/include/nn/parallel/utils.h"
+#include "infini_train/include/nn/parallel/work.h"
 #include "infini_train/include/tensor.h"
 
 namespace infini_train::nn::parallel {
@@ -41,7 +42,7 @@ std::shared_ptr<Tensor> GatherAlongFirstDim(const std::shared_ptr<Tensor> &tenso
     output_shape[0] *= world_size;
 
     auto gathered_output = std::make_shared<Tensor>(output_shape, tensor->Dtype(), device);
-    tp_group->AllGather(gathered_output, tensor);
+    tp_group->AllGather(gathered_output, tensor, false);
 
     return gathered_output;
 }
@@ -61,7 +62,7 @@ std::shared_ptr<Tensor> GatherAlongLastDim(const std::shared_ptr<Tensor> &tensor
     output_shape[0] *= world_size;
 
     auto gathered_output = std::make_shared<Tensor>(output_shape, tensor->Dtype(), device);
-    tp_group->AllGather(gathered_output, tensor);
+    tp_group->AllGather(gathered_output, tensor, false);
 
     // AllGather gather along dim 0 by default
     auto output_list = gathered_output->Split(tensor->Dims()[0], 0);
@@ -102,7 +103,7 @@ std::shared_ptr<Tensor> Reduce(const std::shared_ptr<Tensor> &tensor) {
 
     auto output = std::make_shared<Tensor>(*tensor);
 
-    tp_group->AllReduce(output, function::ReduceOpType::kSum);
+    tp_group->AllReduce(output, function::ReduceOpType::kSum, false);
 
     return output;
 }
@@ -124,7 +125,7 @@ std::shared_ptr<Tensor> ReduceScatterAlongFirstDim(const std::shared_ptr<Tensor>
 
     auto output = std::make_shared<Tensor>(output_shape, tensor->Dtype(), device);
 
-    tp_group->ReduceScatter(output, tensor, function::ReduceOpType::kSum);
+    tp_group->ReduceScatter(output, tensor, function::ReduceOpType::kSum, false);
+
     return output;
 }
 
@@ -463,7 +465,7 @@ VocabParallelCrossEntropy::Forward(const std::vector<std::shared_ptr<Tensor>> &i
     auto local_max = logits_masked->Max(-1);
     auto global_max = local_max;
     if (tp_size > 1) {
-        tp_group->AllReduce(global_max, function::ReduceOpType::kMax);
+        tp_group->AllReduce(global_max, function::ReduceOpType::kMax, false);
     }
 
     auto shifted = logits_masked->Sub(global_max->Unsqueeze(-1));
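
Usage note (not part of the diff): the reworked interface returns a Work handle only when async_op is true; the synchronous path waits on the communication stream internally and returns nullptr. A minimal sketch of how a caller might drive both paths is below. The group name "dp", the helper function, and the tensor variable are illustrative assumptions, not code from this change.

```cpp
// Sketch only, under the assumptions above: a process group named "dp" was
// created elsewhere via ProcessGroupFactory::GetOrCreate, and `grad` lives on
// a CUDA device managed by that group.
#include <memory>

#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/nn/parallel/work.h"
#include "infini_train/include/tensor.h"

using namespace infini_train;
using namespace infini_train::nn::parallel;

void AllReduceGradExample(const std::shared_ptr<Tensor> &grad) {
    const ProcessGroup *pg = ProcessGroupFactory::Instance()->Get("dp");

    // Synchronous path: the collective still runs on the communication stream,
    // but the call waits on it (WaitNonBlocking) before returning nullptr.
    pg->AllReduce(grad, function::ReduceOpType::kAvg, /*async_op=*/false);

    // Asynchronous path: the collective is enqueued between the ready/done
    // events and a Work handle is returned, so the caller can overlap compute
    // on the default stream and synchronize later.
    std::shared_ptr<Work> work = pg->AllReduce(grad, function::ReduceOpType::kAvg, /*async_op=*/true);
    // ... launch independent compute work here ...
    work->WaitNonBlocking(); // or work->WaitBlocking() with a timeout
}
```

The same calling convention applies to AllGather, ReduceScatter, Send, and Recv; the reducer's bucket path in this diff is simply the async_op=true case with the Work handle stored on the bucket.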