18 changes: 10 additions & 8 deletions infini_train/include/nn/parallel/parallel_functional.h
@@ -15,20 +15,22 @@ class Module;
} // namespace infini_train

namespace infini_train::nn::parallel::function {

std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor, ReduceOpType reduce_op,
const ProcessGroup *pg = nullptr, bool async_op = false);

std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
const ProcessGroup *pg = nullptr, bool async_op = false);

std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
ReduceOpType reduce_op, const ProcessGroup *pg = nullptr, bool async_op = false);

std::vector<std::vector<std::shared_ptr<Tensor>>> Scatter(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<const Device *> &device_ids, int dim);

std::vector<std::shared_ptr<Tensor>> Gather(const std::vector<std::vector<std::shared_ptr<Tensor>>> &outputs,
const Device *target_device, int dim);

void AllReduce(const std::shared_ptr<Tensor> &tensor, ReduceOpType reduce_op, const ProcessGroup *pg = nullptr);

void AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
const ProcessGroup *pg = nullptr);

void ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input, ReduceOpType reduce_op,
const ProcessGroup *pg = nullptr);

std::vector<std::vector<std::shared_ptr<Tensor>>>
BroadcastCoalescedReshape(const std::vector<std::shared_ptr<Tensor>> &tensors,
const std::vector<const Device *> &devices);
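A minimal sketch of how the reworked functional API might be driven from a training loop; everything used below is declared in the headers touched by this PR, while the SyncGradient helper itself and the include paths are illustrative assumptions.

    // Sketch only: AllReduce, ReduceOpType, Work and ProcessGroup come from this PR;
    // the SyncGradient helper is an assumed call site, not part of the diff.
    #include "infini_train/include/nn/parallel/parallel_functional.h"
    #include "infini_train/include/nn/parallel/work.h"

    namespace infini_train::nn::parallel::function {

    std::shared_ptr<Work> SyncGradient(const std::shared_ptr<Tensor> &grad, const ProcessGroup *pg) {
        // async_op = true: the collective is issued on the communication stream and a
        // Work handle is returned immediately instead of blocking the caller.
        auto work = AllReduce(grad, ReduceOpType::kSum, pg, /*async_op=*/true);
        // ... independent compute can overlap here ...
        work->WaitBlocking(); // synchronize before the gradient is consumed
        return work;
    }

    } // namespace infini_train::nn::parallel::function
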
107 changes: 79 additions & 28 deletions infini_train/include/nn/parallel/process_group.h
@@ -28,40 +28,99 @@ class Work;

namespace infini_train::nn::parallel {

#ifdef USE_NCCL
class ProcessGroup {
public:
explicit ProcessGroup(const std::string &process_group_name, const std::vector<int> &device_indices);
virtual int GetGroupRank(int global_rank) const;

~ProcessGroup();
// Asynchronous communication APIs (Compute / Communication stream decoupled)
virtual std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor,
function::ReduceOpType reduce_op = function::ReduceOpType::kSum,
bool async_op = false) const
= 0;

int GetGroupRank(int global_rank) const;
virtual std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
bool async_op = false) const
= 0;

// Communication operations
void AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;
virtual std::shared_ptr<Work>
ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
function::ReduceOpType reduce_op = function::ReduceOpType::kSum, bool async_op = false) const
= 0;

void AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input) const;
virtual std::shared_ptr<Work> Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
bool async_op = false) const
= 0;

void ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
function::ReduceOpType reduce_op) const;
virtual std::shared_ptr<Work> Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank,
bool async_op = false) const
= 0;

std::vector<std::shared_ptr<Tensor>> BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const;
// Legacy communication APIs (Single-stream)
virtual std::vector<std::shared_ptr<Tensor>>
BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const = 0;

std::vector<std::shared_ptr<Tensor>>
ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads, const Device *destination) const;
virtual std::vector<std::shared_ptr<Tensor>>
ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads, const Device *destination) const
= 0;

std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor,
std::vector<const Device *> devices, int64_t dim) const;
virtual std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor,
std::vector<const Device *> devices, int64_t dim) const
= 0;

std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, const Device *destination,
int64_t dim) const;
virtual std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors,
const Device *destination, int64_t dim) const
= 0;

protected:
ProcessGroup(int world_size, const std::string &name);

std::vector<const Device *> devices_;

std::unordered_map<int, int> global_group_rank_map_; // global_rank : group_rank

int world_size_ = 0;

const std::string name_ = "";

bool is_main_process_ = false;
};

#ifdef USE_NCCL
class ProcessGroupNCCL final : public ProcessGroup {
public:
explicit ProcessGroupNCCL(const std::string &process_group_name, const std::vector<int> &device_indices);

std::vector<std::shared_ptr<Tensor>> NcclSend(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank) const;
~ProcessGroupNCCL();

std::vector<std::shared_ptr<Tensor>> NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank) const;
// Asynchronous communication APIs (Compute / Communication stream decoupled)
std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op,
bool async_op) const override;

// Async communication functions
std::shared_ptr<Work> AllReduceAsync(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;
std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
bool async_op) const override;

std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
function::ReduceOpType reduce_op, bool async_op) const override;

std::shared_ptr<Work> Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
bool async_op) const override;

std::shared_ptr<Work> Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank,
bool async_op) const override;

// Legacy communication APIs (Single-stream)
std::vector<std::shared_ptr<Tensor>>
BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const override;

std::vector<std::shared_ptr<Tensor>>
ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads,
const Device *destination) const override;

std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor,
std::vector<const Device *> devices, int64_t dim) const override;

std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, const Device *destination,
int64_t dim) const override;

private:
void InitSingleProcess(const std::vector<int> &ranks);
@@ -73,17 +132,9 @@ class ProcessGroup {
private:
std::vector<ncclComm_t> comms_;
std::vector<cudaStream_t> comm_streams_;
std::vector<const Device *> devices_;

std::unordered_map<const Device *, ncclComm_t> device_comm_map_;
std::unordered_map<const Device *, cudaStream_t> device_stream_map_;
std::unordered_map<int, int> global_group_rank_map_; // global_rank : group_rank

int world_size_ = 0;

const std::string name_ = "";

bool is_main_process_ = false;
};
#endif

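With ProcessGroup now an abstract interface and ProcessGroupNCCL the concrete backend behind #ifdef USE_NCCL, call sites can stay backend-agnostic. A rough sketch, written as if inside namespace infini_train::nn::parallel and assuming the ProcessGroupFactory accessors used elsewhere in this PR; tensor is an assumed, already-placed input.

    // The caller programs against the ProcessGroup interface only; whether the group is
    // NCCL-backed was decided where the group was created and registered.
    const ProcessGroup *pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
    auto work = pg->AllReduce(tensor, function::ReduceOpType::kSum, /*async_op=*/true);
    work->WaitBlocking();
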
1 change: 1 addition & 0 deletions infini_train/include/nn/parallel/reduce_op_type.h
@@ -10,4 +10,5 @@ enum class ReduceOpType : int8_t {
kMax,
kAvg,
};

} // namespace infini_train::nn::parallel::function
3 changes: 1 addition & 2 deletions infini_train/include/nn/parallel/work.h
@@ -3,7 +3,6 @@
#include <atomic>
#include <chrono>
#include <exception>
#include <memory>
#include <mutex>

#ifdef USE_CUDA
@@ -44,7 +43,7 @@ class WorkNccl final : public Work {
~WorkNccl() override;

bool WaitBlocking(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) override;
bool WaitNonBlocking();
bool WaitNonBlocking() override;

bool IsCompleted() const override;
bool IsSuccess() const override;
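The Work handle returned by the asynchronous APIs offers both waits, and WaitNonBlocking is now correctly marked override so it dispatches virtually. A sketch of polling instead of blocking, assuming (as the name suggests, but not confirmed by this diff) that WaitNonBlocking() returns true once the communication has completed; pg and grad are carried over from the earlier sketches.

    auto work = pg->AllReduce(grad, function::ReduceOpType::kSum, /*async_op=*/true);
    while (!work->WaitNonBlocking()) {
        // interleave other host-side work here instead of stalling on WaitBlocking()
    }
    CHECK(work->IsSuccess()); // IsCompleted()/IsSuccess() come from the Work interface
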
56 changes: 28 additions & 28 deletions infini_train/src/nn/parallel/parallel_functional.cc
@@ -12,6 +12,34 @@
#include "infini_train/include/tensor.h"

namespace infini_train::nn::parallel::function {

std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor, ReduceOpType reduce_op, const ProcessGroup *pg,
bool async_op) {
auto device = tensor->GetDevice()->Type();
if (pg == nullptr) {
pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
}
return pg->AllReduce(tensor, reduce_op, async_op);
}

std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
const ProcessGroup *pg, bool async_op) {
auto device = output->GetDevice()->Type();
if (pg == nullptr) {
pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
}
return pg->AllGather(output, input, async_op);
}

std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
ReduceOpType reduce_op, const ProcessGroup *pg, bool async_op) {
auto device = output->GetDevice()->Type();
if (pg == nullptr) {
pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
}
return pg->ReduceScatter(output, input, reduce_op, async_op);
}

std::vector<std::vector<std::shared_ptr<Tensor>>> Scatter(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<const Device *> &devices, int dim) {
std::vector<std::vector<std::shared_ptr<Tensor>>> output_tensors;
@@ -34,34 +62,6 @@ std::vector<std::shared_ptr<Tensor>> Gather(const std::vector<std::vector<std::s
return std::make_shared<autograd::Gather>(target_device, dim)->Apply(gather_tensors);
}

void AllReduce(const std::shared_ptr<Tensor> &tensor, ReduceOpType reduce_op, const ProcessGroup *pg) {
// TODO(dcj): use no_grad mode later
auto device = tensor->GetDevice()->Type();
if (pg == nullptr) {
pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
}
pg->AllReduce(tensor, reduce_op);
}

void AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input, const ProcessGroup *pg) {
// TODO(zbl): use no_grad mode later
auto device = output->GetDevice()->Type();
if (pg == nullptr) {
pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
}
pg->AllGather(output, input);
}

void ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input, ReduceOpType reduce_op,
const ProcessGroup *pg) {
// TODO(zbl): use no_grad mode later
auto device = output->GetDevice()->Type();
if (pg == nullptr) {
pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
}
pg->ReduceScatter(output, input, reduce_op);
}

std::vector<std::vector<std::shared_ptr<Tensor>>>
BroadcastCoalescedReshape(const std::vector<std::shared_ptr<Tensor>> &tensors,
const std::vector<const Device *> &devices) {
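Each wrapper above falls back to the factory's default process group when pg is nullptr, so the common single-group case needs no explicit group argument at all; for example (grad again being an assumed, already-placed tensor):

    // pg defaults to nullptr -> the factory's default group; async_op defaults to false.
    function::AllReduce(grad, function::ReduceOpType::kSum);
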
2 changes: 0 additions & 2 deletions infini_train/src/nn/parallel/pp/pipeline_schedule.cc
@@ -2,14 +2,12 @@
#include "infini_train/include/nn/parallel/pp/pipeline_schedule.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "glog/logging.h"

#include "infini_train/include/autocast.h"
#include "infini_train/include/autograd/grad_mode.h"
#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/init.h"
13 changes: 9 additions & 4 deletions infini_train/src/nn/parallel/pp/send_recv.cc
@@ -63,7 +63,7 @@ std::vector<std::shared_ptr<Tensor>> ISend::Forward(const std::vector<std::share
auto pp_group
= ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(input_device_->rank().GlobalRank()));

pp_group->NcclSend(input_tensors, peer_rank_);
pp_group->Send(input_tensors, peer_rank_, false);

return input_tensors;
}
@@ -79,14 +79,16 @@ std::vector<std::shared_ptr<Tensor>> ISend::Backward(const std::vector<std::shar
auto pp_group
= ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(input_device_->rank().GlobalRank()));

return pp_group->NcclRecv(recv_tensors, peer_rank_);
pp_group->Recv(recv_tensors, peer_rank_, false);

return recv_tensors;
}

std::vector<std::shared_ptr<Tensor>> IRecv::Forward(const std::vector<std::shared_ptr<Tensor>> &recv_tensors) {
CHECK_NOTNULL(src_device_);
auto pp_group
= ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(src_device_->rank().GlobalRank()));
pp_group->NcclRecv(recv_tensors, peer_rank_);
pp_group->Recv(recv_tensors, peer_rank_, false);

return recv_tensors;
}
@@ -102,7 +104,10 @@ void IRecv::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tenso
std::vector<std::shared_ptr<Tensor>> IRecv::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
auto pp_group
= ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(cur_device_->rank().GlobalRank()));
return pp_group->NcclSend(grad_outputs, peer_rank_);

pp_group->Send(grad_outputs, peer_rank_, false);

return grad_outputs;
}
} // namespace functions
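
With NcclSend/NcclRecv replaced by the virtual Send/Recv, the pipeline boundary reduces to a matched point-to-point pair on adjacent ranks. A sketch of that pairing; stage_group, activation and activation_buffer are illustrative names, and the receive buffer is assumed to be pre-allocated with the sender's shape and dtype:

    // On the sending pipeline rank r:
    stage_group->Send({activation}, /*dest_rank=*/r + 1, /*async_op=*/false);
    // On the receiving pipeline rank r + 1: Recv fills the pre-allocated buffer in place.
    stage_group->Recv({activation_buffer}, /*src_rank=*/r, /*async_op=*/false);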
