 #include "infini_train/include/nn/modules/loss.h"
 #include "infini_train/include/nn/modules/module.h"
 #include "infini_train/include/nn/parallel/distributed_data_parallel.h"
+#include "infini_train/include/nn/parallel/distributed_optimizer.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
 #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
 #include "infini_train/include/nn/parallel/rank.h"
@@ -48,6 +49,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
 DEFINE_uint32(text_length, 64, "the length of the generated text");
 // optimization
 DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
+DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer (only takes effect when DP>1)");
 // evaluation
 DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
 DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -171,8 +173,11 @@ void Train(const nn::parallel::Rank &rank) {
     // before wrapping the model with DistributedDataParallel (DDP).
     // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
     // are created during the conversion.
+
+    // FIXME(zbl): set as argument
     if (ddp_world_size > 1) {
-        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
+        auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
+        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
     }
 
     auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
@@ -197,7 +202,14 @@ void Train(const nn::parallel::Rank &rank) {
     }
 
     // TODO(dcj): support more complex optimizer later
-    auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate);
+    // auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate);
+    auto optimizer_creator = optimizers::Adam::Create(FLAGS_learning_rate);
+    std::shared_ptr<Optimizer> optimizer
+        = FLAGS_use_distributed_optimizer ? std::make_unique<nn::parallel::DistributedOptimizer>(
+              optimizer_creator, model->Parameters(),
+              dynamic_cast<DistributedDataParallel *>(model.get())->param_grad_buffers(),
+              dynamic_cast<DistributedDataParallel *>(model.get())->bucket_groups(), ddp_pg, ddp_world_size, ddp_rank)
+                                          : optimizer_creator(model->Parameters());
 
     auto train_iter = train_loader.begin();
     std::shared_ptr<nn::Module> loss_fn
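
The hunk above switches from constructing Adam directly to a deferred creator returned by optimizers::Adam::Create, so the same hyper-parameters can be applied either to the full parameter list or to the per-rank shards managed by DistributedOptimizer. The following is a minimal, self-contained sketch of that creator pattern; the Tensor, Adam, and DistributedOptimizer types here are simplified stand-ins for illustration, not the actual infini_train classes.

// Sketch only: simplified stand-ins, not the real infini_train API.
#include <functional>
#include <memory>
#include <utility>
#include <vector>

struct Tensor {}; // placeholder for a parameter tensor

struct Optimizer {
    virtual ~Optimizer() = default;
    virtual void ZeroGrad() = 0;
    virtual void Step() = 0;
};

// A "creator" defers optimizer construction until the final parameter list is
// known, e.g. after a wrapper has sharded the parameters across data-parallel ranks.
using OptimizerCreator = std::function<std::shared_ptr<Optimizer>(std::vector<std::shared_ptr<Tensor>>)>;

struct Adam : Optimizer {
    Adam(std::vector<std::shared_ptr<Tensor>> params, float lr) : params_(std::move(params)), lr_(lr) {}
    void ZeroGrad() override { /* zero all gradients */ }
    void Step() override { /* apply the Adam update with lr_ */ }

    // Factory: captures the hyper-parameters, leaves the parameter list open.
    static OptimizerCreator Create(float lr) {
        return [lr](std::vector<std::shared_ptr<Tensor>> params) -> std::shared_ptr<Optimizer> {
            return std::make_shared<Adam>(std::move(params), lr);
        };
    }

    std::vector<std::shared_ptr<Tensor>> params_;
    float lr_;
};

// A distributed wrapper can use the creator to build an inner optimizer over its
// local shard only, instead of over the full parameter set.
struct DistributedOptimizer : Optimizer {
    DistributedOptimizer(const OptimizerCreator &creator, std::vector<std::shared_ptr<Tensor>> local_shard)
        : inner_(creator(std::move(local_shard))) {}
    void ZeroGrad() override { inner_->ZeroGrad(); }
    void Step() override { inner_->Step(); /* followed by gathering updated shards */ }

    std::shared_ptr<Optimizer> inner_;
};

int main() {
    std::vector<std::shared_ptr<Tensor>> params{std::make_shared<Tensor>()};
    auto creator = Adam::Create(1e-5f);

    bool use_distributed_optimizer = true; // mirrors --use_distributed_optimizer
    std::shared_ptr<Optimizer> optimizer = use_distributed_optimizer
                                               ? std::make_shared<DistributedOptimizer>(creator, params)
                                               : creator(params);

    optimizer->ZeroGrad();
    optimizer->Step();
    return 0;
}
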
@@ -213,13 +225,18 @@ void Train(const nn::parallel::Rank &rank) {
             {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};
 
         model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
-                                                                 pp_rank, std::make_shared<optimizers::Adam>(optimizer),
-                                                                 rank.thread_rank());
+                                                                 pp_rank, optimizer, rank.thread_rank());
     }
 
+    auto cuda_device = device->IsCUDA() ? dynamic_cast<const CudaDevice *>(device) : nullptr;
+
     for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
         const bool last_step = step == FLAGS_num_iteration;
 
+        if (cuda_device) {
+            cuda_device->ResetMemPoolHighWatermarks();
+        }
+
         const auto iter_start = std::chrono::high_resolution_clock::now();
 
         // once in a while evaluate the validation dataset
@@ -246,7 +263,7 @@ void Train(const nn::parallel::Rank &rank) {
         float lossf = 0.0f;
         if (pp_world_size == 1) {
             // model->Train();
-            optimizer.ZeroGrad();
+            optimizer->ZeroGrad();
 
             // if we are trying to overfit a single batch, we reset the loader here
             if (FLAGS_overfit_single_batch) {
@@ -284,7 +301,7 @@ void Train(const nn::parallel::Rank &rank) {
                 LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward";
             }
 
-            optimizer.Step();
+            optimizer->Step();
         } else {
             auto [x, y] = *train_iter;
             // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
@@ -308,10 +325,16 @@ void Train(const nn::parallel::Rank &rank) {
         const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
 
         if (rank.IsLastRank()) {
-            LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
-                                      "DP={}, TP={}, SP={}, PP={})",
+            size_t used_mb = 0, reserved_mb = 0;
+            if (cuda_device) {
+                std::tie(used_mb, reserved_mb) = cuda_device->GetMemPoolPeakMB();
+            }
+
+            LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
+                                      "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
                                       step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
-                                      tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);
+                                      tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
+                                      pp_world_size);
 
             if ((step + 1) % FLAGS_freq_generate_txt == 0) {
                 // FIXME(jym): to support PP
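
The peak-memory reporting introduced across the step-loop and logging hunks follows a reset-then-read high-watermark pattern: the watermarks are cleared at the top of every step and the per-step peaks are read when the step is logged. Below is a small sketch of that pattern using a hypothetical PeakTracker; the real CudaDevice::ResetMemPoolHighWatermarks()/GetMemPoolPeakMB() interface is only mirrored in spirit, not reproduced.

// Sketch only: a hypothetical per-step high-watermark tracker, not the CudaDevice API.
#include <algorithm>
#include <cstddef>

class PeakTracker {
public:
    void Allocate(size_t bytes) {
        used_ += bytes;
        peak_used_ = std::max(peak_used_, used_);
    }
    void Free(size_t bytes) { used_ -= bytes; }

    // Called at the top of each training step so peaks are per-step, not global.
    void ResetHighWatermarks() { peak_used_ = used_; }

    // Per-step peak in MB, mirroring how the log line reports it.
    size_t PeakMB() const { return peak_used_ >> 20; }

private:
    size_t used_ = 0;
    size_t peak_used_ = 0;
};

int main() {
    PeakTracker pool;
    for (int step = 0; step < 3; ++step) {
        pool.ResetHighWatermarks();     // analogous to ResetMemPoolHighWatermarks()
        pool.Allocate(512u << 20);      // transient activations for this step
        pool.Free(512u << 20);
        size_t peak_mb = pool.PeakMB(); // analogous to GetMemPoolPeakMB()
        (void)peak_mb;                  // would be logged alongside loss and throughput
    }
    return 0;
}
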