
Commit be321ca

feat: Add ParamAndGradBuffer related logic and DistributedOptimizer, support ZeRO-1
1 parent 5dffc46 commit be321ca

16 files changed, +1296 −53 lines changed

example/llama3/main.cc

Lines changed: 32 additions & 9 deletions
@@ -13,6 +13,7 @@
 #include "infini_train/include/nn/modules/loss.h"
 #include "infini_train/include/nn/modules/module.h"
 #include "infini_train/include/nn/parallel/distributed_data_parallel.h"
+#include "infini_train/include/nn/parallel/distributed_optimizer.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
 #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
 #include "infini_train/include/nn/parallel/rank.h"
@@ -48,6 +49,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
 DEFINE_uint32(text_length, 64, "the length of the generated text");
 // optimization
 DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
+DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer (only takes effect when DP>1)");
 // evaluation
 DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
 DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -171,8 +173,11 @@ void Train(const nn::parallel::Rank &rank) {
     // before wrapping the model with DistributedDataParallel (DDP).
     // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
     // are created during the conversion.
+
+    // FIXME(zbl): set as argument
     if (ddp_world_size > 1) {
-        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
+        auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
+        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
     }
 
     auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
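
For orientation, num_micro_batches is a token-count ratio: with illustrative values (not taken from this commit) of total_batch_size = 524288, batch_size = 8, sequence_length = 1024 and ddp_world_size = 4, each step runs 524288 / (8 × 1024 × 4) = 16 micro-batches.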
@@ -197,7 +202,14 @@ void Train(const nn::parallel::Rank &rank) {
     }
 
     // TODO(dcj): support more complex optimizer later
-    auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate);
+    // auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate);
+    auto optimizer_creator = optimizers::Adam::Create(FLAGS_learning_rate);
+    std::shared_ptr<Optimizer> optimizer
+        = FLAGS_use_distributed_optimizer ? std::make_unique<nn::parallel::DistributedOptimizer>(
+              optimizer_creator, model->Parameters(),
+              dynamic_cast<DistributedDataParallel *>(model.get())->param_grad_buffers(),
+              dynamic_cast<DistributedDataParallel *>(model.get())->bucket_groups(), ddp_pg, ddp_world_size, ddp_rank)
+                                          : optimizer_creator(model->Parameters());
 
     auto train_iter = train_loader.begin();
     std::shared_ptr<nn::Module> loss_fn
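
The hunk above swaps the concrete Adam instance for an OptimizerCreator factory, so the same call site can build either a plain optimizer or a DistributedOptimizer that wraps one for its parameter shard. The exact definitions of OptimizerCreator and optimizers::Adam::Create are not part of this diff (OptimizerCreator is presumably declared in infini_train/include/optimizer.h, which distributed_optimizer.h includes); the following is only a minimal sketch of the assumed shape, with names and signatures that are assumptions rather than the repository's actual code:

// Hypothetical sketch only, not the commit's code: OptimizerCreator as a deferred factory.
#include <functional>
#include <memory>
#include <vector>

using OptimizerCreator
    = std::function<std::shared_ptr<Optimizer>(const std::vector<std::shared_ptr<Tensor>> &)>;

// A factory such as Adam::Create(lr) would capture the hyperparameters and
// defer binding to a concrete parameter list until the creator is invoked.
inline OptimizerCreator MakeAdamCreator(float lr) {
    return [lr](const std::vector<std::shared_ptr<Tensor>> &params) {
        return std::make_shared<optimizers::Adam>(params, lr);
    };
}

This shape explains why DistributedOptimizer takes a creator rather than a ready-made optimizer: it can instantiate the base optimizer over only its local shard of parameters.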
@@ -213,13 +225,18 @@ void Train(const nn::parallel::Rank &rank) {
             {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};
 
         model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
-                                                                 pp_rank, std::make_shared<optimizers::Adam>(optimizer),
-                                                                 rank.thread_rank());
+                                                                 pp_rank, optimizer, rank.thread_rank());
     }
 
+    auto cuda_device = device->IsCUDA() ? dynamic_cast<const CudaDevice *>(device) : nullptr;
+
     for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
         const bool last_step = step == FLAGS_num_iteration;
 
+        if (cuda_device) {
+            cuda_device->ResetMemPoolHighWatermarks();
+        }
+
         const auto iter_start = std::chrono::high_resolution_clock::now();
 
         // once in a while evaluate the validation dataset
@@ -246,7 +263,7 @@ void Train(const nn::parallel::Rank &rank) {
         float lossf = 0.0f;
         if (pp_world_size == 1) {
             // model->Train();
-            optimizer.ZeroGrad();
+            optimizer->ZeroGrad();
 
             // if we are trying to overfit a single batch, we reset the loader here
             if (FLAGS_overfit_single_batch) {
@@ -284,7 +301,7 @@ void Train(const nn::parallel::Rank &rank) {
                 LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward";
             }
 
-            optimizer.Step();
+            optimizer->Step();
         } else {
             auto [x, y] = *train_iter;
             // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
@@ -308,10 +325,16 @@ void Train(const nn::parallel::Rank &rank) {
         const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
 
         if (rank.IsLastRank()) {
-            LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
-                                      "DP={}, TP={}, SP={}, PP={})",
+            size_t used_mb = 0, reserved_mb = 0;
+            if (cuda_device) {
+                std::tie(used_mb, reserved_mb) = cuda_device->GetMemPoolPeakMB();
+            }
+
+            LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
+                                      "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
                                       step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
-                                      tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);
+                                      tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
+                                      pp_world_size);
 
             if ((step + 1) % FLAGS_freq_generate_txt == 0) {
                 // FIXME(jym): to support PP

infini_train/include/device.h

Lines changed: 3 additions & 0 deletions
@@ -68,6 +68,9 @@ class CudaDevice : public Device {
 
     nn::parallel::Rank rank() const override;
 
+    void ResetMemPoolHighWatermarks() const;
+    std::pair<size_t, size_t> GetMemPoolPeakMB() const;
+
 private:
     CudaDevice(int8_t index);
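
ResetMemPoolHighWatermarks() and GetMemPoolPeakMB() expose peak memory-pool usage; their definitions are not part of this header diff. Below is a minimal sketch of the high-watermark bookkeeping they imply, with illustrative names that are not the commit's actual implementation:

// Illustrative only: a tiny high-watermark tracker, not the commit's CUDA memory pool.
#include <algorithm>
#include <cstddef>
#include <utility>

class MemPoolWatermarks {
public:
    // Called on every allocation with the pool's current used/reserved byte counts.
    void OnAlloc(size_t used_bytes, size_t reserved_bytes) {
        peak_used_ = std::max(peak_used_, used_bytes);
        peak_reserved_ = std::max(peak_reserved_, reserved_bytes);
    }
    // Analogue of CudaDevice::ResetMemPoolHighWatermarks(): forget past peaks before the next step.
    void Reset() { peak_used_ = peak_reserved_ = 0; }
    // Analogue of CudaDevice::GetMemPoolPeakMB(): report {peak used, peak reserved} in MB.
    std::pair<size_t, size_t> PeakMB() const { return {peak_used_ >> 20, peak_reserved_ >> 20}; }

private:
    size_t peak_used_ = 0;
    size_t peak_reserved_ = 0;
};

Because main.cc resets the watermarks at the top of each step and reads the peaks when it logs, the reported numbers are per-iteration peaks rather than process-lifetime peaks.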

infini_train/include/nn/parallel/distributed_data_parallel.h

Lines changed: 25 additions & 2 deletions
@@ -3,6 +3,8 @@
 #include <memory>
 
 #include "infini_train/include/nn/modules/module.h"
+#include "infini_train/include/nn/parallel/distributed_data_parallel_config.h"
+#include "infini_train/include/nn/parallel/param_and_grad_buffer.h"
 #include "infini_train/include/nn/parallel/reducer.h"
 
 namespace infini_train {
@@ -14,13 +16,34 @@ namespace infini_train::nn::parallel {
 
 class DistributedDataParallel : public nn::Module {
 public:
-    DistributedDataParallel(std::shared_ptr<nn::Module> module, int device_id,
-                            const ReducerOptions &opts = ReducerOptions{});
+    DistributedDataParallel(std::shared_ptr<nn::Module> module, int thread_rank,
+                            DistributedDataParallelConfig ddp_config = DistributedDataParallelConfig());
 
     std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
 
+    DistributedDataParallelConfig ddp_config() const { return ddp_config_; }
+
+    const std::vector<std::shared_ptr<ParamAndGradBuffer>> &param_grad_buffers() const { return param_grad_buffers_; }
+
+    const std::vector<std::shared_ptr<ParamAndGradBucketGroup>> &bucket_groups() const { return bucket_groups_; }
+
+private:
+    void BuildParamAndGradBuffers();
+    void RegisterBackwardHooks();
+    void OnGradReady(const std::shared_ptr<Tensor> &param);
+
 private:
     std::shared_ptr<Reducer> reducer_ = nullptr;
+
+    DistributedDataParallelConfig ddp_config_;
+    const ProcessGroup *ddp_pg_ = nullptr;
+
+    std::vector<std::shared_ptr<ParamAndGradBuffer>> param_grad_buffers_;
+    std::vector<std::shared_ptr<ParamAndGradBucketGroup>> bucket_groups_;
+    std::unordered_map<Tensor *, std::shared_ptr<ParamAndGradBucketGroup>> param_to_bucket_group_;
+
+    std::atomic<size_t> num_params_ready_{0};
+    size_t total_params_{0};
 };
 
 } // namespace infini_train::nn::parallel
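
The new private members outline how gradient readiness is tracked: RegisterBackwardHooks() attaches a per-parameter hook, OnGradReady() is the callback, param_to_bucket_group_ maps each parameter to its owning bucket group, and num_params_ready_/total_params_ count progress. The definitions live in the .cc file and are not shown here; the following is a hedged sketch of the callback, where RegisterGradReady() on the bucket group is a hypothetical method name:

// Hedged sketch of the bookkeeping implied by the declarations above; not the commit's code.
void DistributedDataParallel::OnGradReady(const std::shared_ptr<Tensor> &param) {
    if (ddp_config_.overlap_grad_reduce) {
        // Hand the ready gradient to the owning bucket group, which can launch its
        // reduce-scatter/all-reduce as soon as all of its parameters have reported.
        if (auto it = param_to_bucket_group_.find(param.get()); it != param_to_bucket_group_.end()) {
            it->second->RegisterGradReady(param); // hypothetical bucket-group API
        }
    }
    // Once every parameter has reported, reset the counter for the next backward pass.
    if (++num_params_ready_ == total_params_) {
        num_params_ready_ = 0;
    }
}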

infini_train/include/nn/parallel/distributed_data_parallel_config.h

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+#pragma once
+
+#include <limits>
+
+namespace infini_train::nn::parallel {
+namespace {
+// Default bucket sizes, in alignment with PyTorch
+constexpr int kFirstBucketCapMB = 1;
+constexpr int kNormalBucketCapMB = 25;
+} // namespace
+
+class DistributedDataParallelConfig {
+public:
+    // ======================================================
+    // Reducer-related args
+    // Ref: https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
+    // ======================================================
+    // Max capacity for each bucket (in MB).
+    size_t first_bucket_cap_mb = kFirstBucketCapMB;
+    size_t normal_bucket_cap_mb = kNormalBucketCapMB;
+
+    // When set true, map param.grad directly to a slice of bucket.flat (same address in memory) instead of memcpy.
+    bool gradient_as_bucket_view = true;
+
+    // Whether to enable gradient bucketing.
+    bool gradient_bucketing_enabled = true;
+
+    // ======================================================
+    // DistributedOptimizer-related args
+    // Ref:
+    // https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/distributed_data_parallel_config.py
+    // ======================================================
+    // Whether to enable DistributedOptimizer (ZeRO-1 equivalent).
+    // When set true:
+    // 1) Gradients/params are managed by ParamAndGradBuffer and reduced in groups.
+    // 2) The classic DDP reducer path is not used (i.e., reducer/bucketing in the DDP sense is disabled).
+    bool use_distributed_optimizer = false;
+
+    // Whether to overlap gradient reduce-scatter/all-reduce with backward compute.
+    // When enabled, grad reduce is triggered as soon as a grad is ready; otherwise it waits until all grads are ready.
+    bool overlap_grad_reduce = true;
+
+    // Whether to overlap parameter all-gather with forward compute.
+    bool overlap_param_gather = true;
+
+    // Whether to average values inside collectives (divide by world size) instead of summing.
+    bool average_in_collective = true;
+
+    // Whether to check for NaNs/Infs/unusually large values in gradients before collectives.
+    bool check_for_nan_in_grad = false;
+    bool check_for_large_grads = false;
+
+    // Number of DistributedOptimizer instances.
+    // Multiple instances are used to build hierarchical collective groups for params/grads.
+    int num_distributed_optimizer_instances = 1;
+
+    // Maximum number of parameters in each ParamAndGradBucket.
+    // This is distinct from the DDP Reducer's MB-based bucket caps.
+    size_t bucket_size_in_elements = std::numeric_limits<size_t>::max();
+
+    // Whether to pad bucket sizes to improve NCCL bus bandwidth utilization.
+    bool pad_buckets_for_high_nccl_busbw = false;
+};
+} // namespace infini_train::nn::parallel

infini_train/include/nn/parallel/distributed_optimizer.h

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "infini_train/include/nn/parallel/param_and_grad_buffer.h"
+#include "infini_train/include/optimizer.h"
+
+namespace infini_train::nn::parallel {
+
+class DistributedOptimizer final : public infini_train::Optimizer {
+public:
+    DistributedOptimizer(OptimizerCreator inner_optimizer_creator,
+                         const std::vector<std::shared_ptr<Tensor>> &full_params,
+                         const std::vector<std::shared_ptr<ParamAndGradBuffer>> &buffers,
+                         const std::vector<std::shared_ptr<ParamAndGradBucketGroup>> &bucket_groups,
+                         const ProcessGroup *dp_pg, size_t dp_world_size, size_t ddp_rank);
+
+    void Step() override;
+
+    void ZeroGrad(bool set_to_none = true) override;
+
+    void StartGradSync();
+    void FinishGradSync();
+
+    void StartParamSync(bool force_sync = false);
+    void FinishParamSync(bool skip_next_bucket_dispatch = false);
+
+private:
+    void BuildShardParamsAndBindGrads();
+
+private:
+    // Inherited from the DDP model
+    std::vector<std::shared_ptr<ParamAndGradBuffer>> param_grad_buffers_;
+    std::vector<std::shared_ptr<ParamAndGradBucketGroup>> bucket_groups_;
+
+    // DP info
+    const ProcessGroup *dp_pg_;
+    size_t dp_world_size_;
+    size_t dp_rank_;
+
+    // Shard params
+    std::vector<std::shared_ptr<Tensor>> shard_params_;
+
+    // Base optimizer (e.g., SGD, Adam)
+    OptimizerCreator creator_;
+    std::shared_ptr<Optimizer> base_optimizer_;
+};
+
+} // namespace infini_train::nn::parallel
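
The header only declares the ZeRO-1 surface; the expected sequencing, in the spirit of Megatron-LM's distributed optimizer, is: finish the gradient reduce-scatter, step the base optimizer on this rank's 1/dp_world_size parameter shard, then all-gather the updated parameters back into the buffers. A hedged sketch of how Step() could be wired, using only the declarations above (the real definition is in the .cc file, not shown here):

// Hedged sketch of a ZeRO-1 style Step(); not the commit's actual implementation.
void DistributedOptimizer::Step() {
    FinishGradSync();        // wait for the grad reduce-scatter/all-reduce to complete
    base_optimizer_->Step(); // update only the locally owned parameter shard
    StartParamSync();        // all-gather updated params back into the full buffers
    FinishParamSync();       // or leave it pending to overlap with the next forward pass
}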
