
Commit 5dffc46

fix: resolve review comments
1 parent 7be7741 commit 5dffc46

6 files changed: +25 -28 lines

infini_train/include/nn/parallel/process_group.h

Lines changed: 13 additions & 10 deletions
@@ -34,22 +34,25 @@ class ProcessGroup {
 
     // Asynchronous communication APIs (Compute / Communication stream decoupled)
     virtual std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor,
-                                            const function::AllreduceOptions &opts) const
+                                            function::ReduceOpType reduce_op = function::ReduceOpType::kSum,
+                                            bool async_op = false) const
         = 0;
 
     virtual std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
-                                            bool async_op) const
+                                            bool async_op = false) const
         = 0;
 
-    virtual std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output,
-                                                const std::shared_ptr<Tensor> &input,
-                                                const function::AllreduceOptions &opts) const
+    virtual std::shared_ptr<Work>
+    ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
+                  function::ReduceOpType reduce_op = function::ReduceOpType::kSum, bool async_op = false) const
         = 0;
 
-    virtual std::shared_ptr<Work> Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank, bool async_op) const
+    virtual std::shared_ptr<Work> Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
+                                       bool async_op = false) const
         = 0;
 
-    virtual std::shared_ptr<Work> Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank, bool async_op) const
+    virtual std::shared_ptr<Work> Recv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank,
+                                       bool async_op = false) const
         = 0;
 
     // Legacy communication APIs (Single-stream)
@@ -90,14 +93,14 @@ class ProcessGroupNCCL final : public ProcessGroup {
     ~ProcessGroupNCCL();
 
     // Asynchronous communication APIs (Compute / Communication stream decoupled)
-    std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor,
-                                    const function::AllreduceOptions &opts) const override;
+    std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op,
+                                    bool async_op) const override;
 
     std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
                                     bool async_op) const override;
 
     std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
-                                        const function::AllreduceOptions &opts) const override;
+                                        function::ReduceOpType reduce_op, bool async_op) const override;
 
     std::shared_ptr<Work> Send(std::vector<std::shared_ptr<Tensor>> tensors, int dest_rank,
                                bool async_op) const override;
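
The asynchronous API now takes an explicit function::ReduceOpType and an async_op flag (defaulting to kSum and blocking behaviour) instead of a function::AllreduceOptions struct. A minimal caller-side sketch of the new interface; the surrounding setup is hypothetical (includes and namespace qualifiers omitted), and only the process-group calls themselves come from this commit:

void ExampleAllReduce(const std::shared_ptr<Tensor> &tensor) {
    // Default process group, as used by the functional wrappers in parallel_functional.cc.
    auto pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();

    // Defaults apply: reduce_op = ReduceOpType::kSum, async_op = false (blocks until done).
    pg->AllReduce(tensor);

    // Non-default arguments: average asynchronously and keep the Work handle.
    auto work = pg->AllReduce(tensor, function::ReduceOpType::kAvg, /*async_op=*/true);
    // ... overlap other compute here ...
    work->WaitNonBlocking(); // synchronize once the reduced values are needed
}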

infini_train/include/nn/parallel/reduce_op_type.h

Lines changed: 0 additions & 5 deletions
@@ -11,9 +11,4 @@ enum class ReduceOpType : int8_t {
     kAvg,
 };
 
-struct AllreduceOptions {
-    ReduceOpType reduce_op_type = ReduceOpType::kSum;
-    bool async_op = false;
-};
-
 } // namespace infini_train::nn::parallel::function

infini_train/src/nn/parallel/parallel_functional.cc

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@ std::shared_ptr<Work> AllReduce(const std::shared_ptr<Tensor> &tensor, ReduceOpT
     if (pg == nullptr) {
         pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
     }
-    return pg->AllReduce(tensor, {reduce_op, async_op});
+    return pg->AllReduce(tensor, reduce_op, async_op);
 }
 
 std::shared_ptr<Work> AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input,
@@ -37,7 +37,7 @@ std::shared_ptr<Work> ReduceScatter(const std::shared_ptr<Tensor> &output, const
     if (pg == nullptr) {
         pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup();
     }
-    return pg->ReduceScatter(output, input, {reduce_op, async_op});
+    return pg->ReduceScatter(output, input, reduce_op, async_op);
 }
 
 std::vector<std::vector<std::shared_ptr<Tensor>>> Scatter(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
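
The free functions in function:: forward to the process group (falling back to the default group when pg is null), so call sites no longer build a brace-initialized options struct. A hedged usage sketch; the wrapper signatures are truncated in the hunk headers above, so the parameter order shown here is an assumption to be checked against the header:

// Assumed order: (tensors..., reduce_op, async_op[, pg]); tensor/output/input are placeholders.
auto sum_work = function::AllReduce(tensor, function::ReduceOpType::kSum, /*async_op=*/false);
auto rs_work  = function::ReduceScatter(output, input, function::ReduceOpType::kSum, /*async_op=*/true);
rs_work->WaitNonBlocking(); // only needed on the async path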

infini_train/src/nn/parallel/process_group.cc

Lines changed: 6 additions & 7 deletions
@@ -188,7 +188,7 @@ void ProcessGroupNCCL::InitStreams() {
 }
 
 std::shared_ptr<Work> ProcessGroupNCCL::AllReduce(const std::shared_ptr<Tensor> &tensor,
-                                                  const function::AllreduceOptions &opts) const {
+                                                  function::ReduceOpType reduce_op, bool async_op) const {
     void *buffer = tensor->DataPtr();
     const auto *device = dynamic_cast<const CudaDevice *>(tensor->GetDevice());
     device->SetDevice();
@@ -208,11 +208,11 @@ std::shared_ptr<Work> ProcessGroupNCCL::AllReduce(const std::shared_ptr<Tensor>
 
     // Perform NcclAllReduce on comm stream
     NCCL_CHECK(ncclAllReduce(buffer, buffer, tensor->NumElements(), kNcclDtypeMap.at(tensor->Dtype()),
-                             kNcclReduceOpMap.at(opts.reduce_op_type), comm, comm_stream));
+                             kNcclReduceOpMap.at(reduce_op), comm, comm_stream));
 
     CUDA_CHECK(cudaEventRecord(done_event, comm_stream));
 
-    if (opts.async_op) {
+    if (async_op) {
         return std::move(work);
     } else {
         work->WaitNonBlocking();
@@ -253,7 +253,7 @@ std::shared_ptr<Work> ProcessGroupNCCL::AllGather(const std::shared_ptr<Tensor>
 
 std::shared_ptr<Work> ProcessGroupNCCL::ReduceScatter(const std::shared_ptr<Tensor> &output,
                                                       const std::shared_ptr<Tensor> &input,
-                                                      const function::AllreduceOptions &opts) const {
+                                                      function::ReduceOpType reduce_op, bool async_op) const {
     const auto *device = dynamic_cast<const CudaDevice *>(input->GetDevice());
     auto comm = device_comm_map_.at(device);
 
@@ -271,12 +271,11 @@ std::shared_ptr<Work> ProcessGroupNCCL::ReduceScatter(const std::shared_ptr<Tens
     CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0));
 
     NCCL_CHECK(ncclReduceScatter(input->DataPtr(), output->DataPtr(), output->NumElements(),
-                                 kNcclDtypeMap.at(input->Dtype()), kNcclReduceOpMap.at(opts.reduce_op_type), comm,
-                                 comm_stream));
+                                 kNcclDtypeMap.at(input->Dtype()), kNcclReduceOpMap.at(reduce_op), comm, comm_stream));
 
     CUDA_CHECK(cudaEventRecord(done_event, comm_stream));
 
-    if (opts.async_op) {
+    if (async_op) {
         return std::move(work);
     } else {
         work->WaitNonBlocking();
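
In the NCCL implementation, async_op now flows in directly as a parameter and decides what the caller gets back: the collective is enqueued on the communication stream after it waits on the compute stream's ready event, a done event is recorded for the Work object, and then either the Work is returned immediately or the call blocks. Restated as a sketch (the comments are interpretation; only the uncommented structure appears in the hunks above):

CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0));   // comm stream waits for pending compute-stream work
NCCL_CHECK(/* collective enqueued on comm_stream, e.g. ncclAllReduce(...) */);
CUDA_CHECK(cudaEventRecord(done_event, comm_stream));           // completion marker tracked by the Work object
if (async_op) {
    return std::move(work);      // hand the Work back immediately; caller overlaps compute and syncs later
} else {
    work->WaitNonBlocking();     // blocking path: wait here before returning
}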

infini_train/src/nn/parallel/reducer.cc

Lines changed: 1 addition & 1 deletion
@@ -419,7 +419,7 @@ void Reducer::FinalizeBucketDense(size_t bucket_index) {
         // FIXME(zbl): support custom hook later
         LOG(FATAL) << "Custom hook is not supported now";
     } else {
-        bucket.work = ddp_pg->AllReduce(bucket.contents, {function::ReduceOpType::kAvg, true});
+        bucket.work = ddp_pg->AllReduce(bucket.contents, function::ReduceOpType::kAvg, true);
     }
 }

infini_train/src/nn/parallel/tensor_parallel.cc

Lines changed: 3 additions & 3 deletions
@@ -103,7 +103,7 @@ std::shared_ptr<Tensor> Reduce(const std::shared_ptr<Tensor> &tensor) {
 
     auto output = std::make_shared<Tensor>(*tensor);
 
-    tp_group->AllReduce(output, {function::ReduceOpType::kSum, false});
+    tp_group->AllReduce(output, function::ReduceOpType::kSum, false);
     return output;
 }
 
@@ -125,7 +125,7 @@ std::shared_ptr<Tensor> ReduceScatterAlongFirstDim(const std::shared_ptr<Tensor>
 
     auto output = std::make_shared<Tensor>(output_shape, tensor->Dtype(), device);
 
-    tp_group->ReduceScatter(output, tensor, {function::ReduceOpType::kSum, false});
+    tp_group->ReduceScatter(output, tensor, function::ReduceOpType::kSum, false);
 
     return output;
 }
@@ -465,7 +465,7 @@ VocabParallelCrossEntropy::Forward(const std::vector<std::shared_ptr<Tensor>> &i
     auto local_max = logits_masked->Max(-1);
    auto global_max = local_max;
     if (tp_size > 1) {
-        tp_group->AllReduce(global_max, {function::ReduceOpType::kMax, false});
+        tp_group->AllReduce(global_max, function::ReduceOpType::kMax, false);
     }
     auto shifted = logits_masked->Sub(global_max->Unsqueeze(-1));
