diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc
index 14b25c4..4a34c46 100644
--- a/example/gpt2/main.cc
+++ b/example/gpt2/main.cc
@@ -188,15 +188,35 @@ void Train(const nn::parallel::Rank &rank) {
         LOG(FATAL) << "Rank " << rank.GlobalRank() << ": Datatype " << FLAGS_dtype << " not supported.";
     }
-    // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
-    // before wrapping the model with DistributedDataParallel (DDP).
-    // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
-    // are created during the conversion.
-    if (ddp_world_size > 1) {
+    auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
+
+    // TODO(dcj): support more complex optimizer later
+    auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
+
+    if (pp_world_size > 1) {
+        // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
+        // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
+        auto shapes = std::vector>{
+            {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};
+
+        model = std::make_shared(
+            model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared(optimizer),
+            rank.thread_rank(), std::dynamic_pointer_cast(model)->GetChunkSize());
+        if (ddp_world_size > 1) {
+            auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks();
+            for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
+                (*mutable_chunks)[chunk_id]
+                    = std::make_shared(mutable_chunks->at(chunk_id), rank.thread_rank());
+            }
+        }
+    } else if (ddp_world_size > 1) {
+        // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
+        // before wrapping the model with DistributedDataParallel (DDP).
+        // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
+        // are created during the conversion.
         model = std::make_shared(model, rank.thread_rank());
     }
-    auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
     DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length),
                                        pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size,
                                        ddp_rank, ddp_world_size);
@@ -217,9 +237,6 @@ void Train(const nn::parallel::Rank &rank) {
         tokenizer = std::make_unique(FLAGS_tokenizer_bin);
     }
-    // TODO(dcj): support more complex optimizer later
-    auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
-
     auto train_iter = train_loader.begin();
     std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast(
@@ -228,17 +245,6 @@
     loss_fn->To(device);
     LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";
-    if (pp_world_size > 1) {
-        // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
-        // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
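// A minimal, self-contained sketch of the wrapping order introduced in the hunk above: under
// pipeline parallelism the model is wrapped in PipelineParallel first, and DDP is then applied
// to each pipeline chunk rather than to the whole model. The stub types and constructor
// signatures below are assumptions for illustration only, not the library's verbatim API.
#include <memory>
#include <utility>
#include <vector>

struct Module {
    virtual ~Module() = default;
};

// Stand-in for nn::parallel::DistributedDataParallel.
struct DistributedDataParallel : Module {
    DistributedDataParallel(std::shared_ptr<Module> wrapped, int device_id)
        : wrapped_(std::move(wrapped)), device_id_(device_id) {}
    std::shared_ptr<Module> wrapped_;
    int device_id_;
};

// Stand-in for nn::parallel::PipelineParallel: owns one sub-module ("chunk") per virtual stage.
struct PipelineParallel : Module {
    explicit PipelineParallel(std::vector<std::shared_ptr<Module>> chunks) : chunks_(std::move(chunks)) {}
    std::vector<std::shared_ptr<Module>> *mutable_chunks() { return &chunks_; }
    std::vector<std::shared_ptr<Module>> chunks_;
};

// PP first (so each stage owns its local chunks), then DDP around every chunk so that the
// gradient hooks are registered on exactly the parameters this rank trains. In the real code
// the chunks are produced by the model's own constructor; here they are passed in directly.
std::shared_ptr<Module> WrapForParallel(std::shared_ptr<Module> model, std::vector<std::shared_ptr<Module>> chunks,
                                        int pp_world_size, int ddp_world_size, int device_id) {
    if (pp_world_size > 1) {
        auto pp = std::make_shared<PipelineParallel>(std::move(chunks));
        if (ddp_world_size > 1) {
            for (auto &chunk : *pp->mutable_chunks()) {
                chunk = std::make_shared<DistributedDataParallel>(chunk, device_id);
            }
        }
        return pp;
    }
    if (ddp_world_size > 1) {
        return std::make_shared<DistributedDataParallel>(std::move(model), device_id);
    }
    return model;
}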
- auto shapes = std::vector>{ - {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; - - model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, - pp_rank, std::make_shared(optimizer), - rank.thread_rank()); - } - LOG(INFO) << "start training"; for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index 9ccbb48..69e4278 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -4,9 +4,12 @@ #include #include #include +#include +#include #include #include #include +#include #include "glog/logging.h" @@ -175,172 +178,20 @@ Block::Forward(const std::vector> &x) { return {x2}; } -GPT2::GPT2(const GPT2Config &config) : config_(config) { - int pp_size = nn::parallel::global::GetPipelineParallelSize(); - int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); - auto pp_rank = nn::parallel::pp_rank; - auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] - = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, pp_rank, vpp_size); - - auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); - - // NOTE(zbl): VocabParallelEmbedding requires vocab_size % tp_size == 0 - // Megatron-LM has an optional argument `--make-vocab-size-divisible-by`, would do padding to vocab - // Here we introduce padding by default, might need modify Tokenizer correspondingly later - CHECK_EQ(config.vocab_size % tp_world_size, 0) << "Vocab size should be divisible by TP world size"; - { - std::unordered_map> transformer; - if (is_first_stage) { - transformer[kWTELayerName] = std::make_shared( - config_.vocab_size, config_.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); - transformer[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); - } - - { - std::vector> h; - for (const auto &[start_layer, end_layer] : layer_ranges_per_chunk) { - for (int64_t i = start_layer; i < end_layer; ++i) { h.push_back(std::make_shared(config)); } - } - - transformer[kHLayerName] = std::make_shared(std::move(h)); - if (is_last_stage) { - transformer[kLnFLayerName] = std::make_shared(std::vector{config_.n_embd}); - } - - modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); - } - - if (is_last_stage) { - // don't init this one, we will tie weights - modules_[kLMHeadLayerName] = std::make_shared( - /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, - /*bias=*/false, - // NOTE(zbl): each tp_rank would get sharded [B, T, V_local] as logits - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - } - - // FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real tying operation - if (pp_size == 1) { - // https://paperswithcode.com/method/weight-tying - *mutable_module(kTransformerLayerName) - ->mutable_module(kWTELayerName) - ->mutable_parameter(nn::parallel::VocabParallelEmbedding::kParamWeightName) - = module(kLMHeadLayerName).parameter(nn::parallel::ColumnParallelLinear::kParamWeightName); - } - } -} - -std::vector> GPT2::BuildChunks(int pp_rank) { - int pp_size = nn::parallel::global::GetPipelineParallelSize(); - int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); - - auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] - = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, pp_rank, vpp_size); - - std::vector> chunks; - 
chunks.reserve(layer_ranges_per_chunk.size()); - - int stage_layer_off = 0; - for (size_t i = 0; i < layer_ranges_per_chunk.size(); ++i) { - auto [start, end] = layer_ranges_per_chunk[i]; - int chunk_layers = end - start; - chunks.emplace_back(std::make_shared( - this, - stage_layer_off, // Starting offset for layer indexing - chunk_layers, // Number of layers in this chunk - (i == 0 && is_first_stage), // Whether to include embedding - (i == layer_ranges_per_chunk.size() - 1 && is_last_stage), // Whether to include lm_head - config_)); - stage_layer_off += chunk_layers; - } - - return chunks; +GPT2FirstStage::GPT2FirstStage(const GPT2Config &config) : config_(config) { + modules_[kWTELayerName] = std::make_shared( + config_.vocab_size, config_.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); + modules_[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); } std::vector> -GPT2Chunk::Forward(const std::vector> &input) { - auto transformer = parent_->mutable_module(GPT2::kTransformerLayerName); +GPT2FirstStage::Forward(const std::vector> &input) { // (B, T) auto x1 = input[0]; + CHECK_LE(x1->Dims()[1], config_.block_size) + << "Cannot forward sequence of length " << x1->Dims()[1] << ", block size is only " << config_.block_size; const auto device = x1->GetDevice(); - int pp_size = nn::parallel::global::GetPipelineParallelSize(); - int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); - - auto pp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( - nn::parallel::GetPipelineParallelProcessGroupName(device->rank().GlobalRank())); - auto pp_rank = pp_group->GetGroupRank(device->rank().GlobalRank()); - - auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] - = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, pp_rank, vpp_size); - - bool is_pipeline_first_chunk = is_first_stage && (layer_ranges_per_chunk[0].first == 0); - const auto t = x1->Dims()[1] - * (is_pipeline_first_chunk ? 1 : nn::parallel::global::GetSequenceParallelSize()); // full_seq_len - CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " - << config_.block_size; - - // forward the GPT2 model itself - if (has_embedding_) { - // (T_local) - // NOTE(zbl): Slice pos sequence when SP is enabled - auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); - auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); - int tp_rank = 0; - if (tp_world_size > 1) { - auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( - nn::parallel::GetTensorParallelProcessGroupName(device->rank().GlobalRank())); - tp_rank = tp_group->GetGroupRank(device->rank().GlobalRank()); - } - int64_t t_local = sequence_parallel_enabled ? (t / tp_world_size) : t; - int64_t start = sequence_parallel_enabled ? 
tp_rank * t_local : 0; - auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); - - // (B, T) -> Embedding(V_local, C) -> (B, T, C) - auto tok_emb = transformer->mutable_module(GPT2::kWTELayerName)->Forward({x1})[0]; - - // (T) -> Embedding(T_max, C) -> (T, C) - auto pos_emb = transformer->mutable_module(GPT2::kWPELayerName)->Forward({pos})[0]; - // (B, T, C) - x1 = tok_emb + pos_emb; - } - - // (B, T, C) -> transformer -> (B, T, C) - auto h_modules = transformer->mutable_module(GPT2::kHLayerName); - CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; - auto blocks = std::dynamic_pointer_cast(h_modules); - // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (int i = 0; i < chunk_layers_; ++i) { x1 = (*blocks)[layer_begin_ + i]->Forward({x1})[0]; } - - // (B, T, C) -> Layernorm -> (B, T, C) - if (has_lm_head_) { - auto x2 = transformer->mutable_module(GPT2::kLnFLayerName)->Forward({x1}); - - // TODO(dcj): add inference-time mini-optimization - // (B, T, C) -> Linear(C, V) -> (B, T, V) - auto logits = parent_->mutable_module(GPT2::kLMHeadLayerName)->Forward(x2); - - return logits; - } - - return {x1}; -} - -std::vector> -GPT2::Forward(const std::vector> &x) { - // (B, T) - auto x1 = x[0]; - const auto device = x1->GetDevice(); - - const auto t = x1->Dims()[1]; // T - CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " - << config_.block_size; - // forward the GPT2 model itself - auto &transformer = modules_[kTransformerLayerName]; - // (T_local) // NOTE(zbl): Slice pos sequence when SP is enabled auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); @@ -351,32 +202,125 @@ GPT2::Forward(const std::vector> &x) { nn::parallel::GetTensorParallelProcessGroupName(device->rank().GlobalRank())); tp_rank = tp_group->GetGroupRank(device->rank().GlobalRank()); } - int64_t t_local = sequence_parallel_enabled ? (t / tp_world_size) : t; + int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1]; int64_t start = sequence_parallel_enabled ? 
tp_rank * t_local : 0; auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); // (B, T) -> Embedding(V_local, C) -> (B, T, C) - auto tok_emb = transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; + auto tok_emb = modules_[kWTELayerName]->Forward({x1})[0]; + // (T) -> Embedding(T_max, C) -> (T, C) - auto pos_emb = transformer->mutable_module(kWPELayerName)->Forward({pos})[0]; + auto pos_emb = modules_[kWPELayerName]->Forward({pos})[0]; // (B, T, C) - x1 = tok_emb + pos_emb; + return {tok_emb + pos_emb}; +} - // (B, T, C) -> transformer -> (B, T, C) - auto h_modules = transformer->mutable_module(kHLayerName); - CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; - auto h_layers = std::dynamic_pointer_cast(h_modules); +GPT2Chunk::GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer) : config_(config) { + std::vector> h; + for (int64_t i = start_layer; i < end_layer; ++i) { + auto layer = std::make_shared(config); + h.push_back(layer); + } + modules_[kHLayerName] = std::make_shared(std::move(h)); +} + +std::vector> +GPT2Chunk::Forward(const std::vector> &x) { + auto x1 = x[0]; // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (auto &h : *h_layers) { x1 = h->Forward({x1})[0]; } + for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { x1 = h->Forward({x1})[0]; } + return {x1}; +} + +GPT2LastStage::GPT2LastStage(const GPT2Config &config) : config_(config) { + modules_[kLnFLayerName] = std::make_shared(std::vector{config_.n_embd}); + // don't init this one, we will tie weights + modules_[kLMHeadLayerName] = std::make_shared( + /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, + /*bias=*/false, + // NOTE(zbl): each tp_rank would get sharded [B, T, V_local] as logits + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); +} +std::vector> +GPT2LastStage::Forward(const std::vector> &x) { // (B, T, C) -> Layernorm -> (B, T, C) - auto x3 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); + auto x1 = modules_[kLnFLayerName]->Forward(x); // TODO(dcj): add inference-time mini-optimization // (B, T, C) -> Linear(C, V) -> (B, T, V) - auto logits = modules_[kLMHeadLayerName]->Forward(x3); - // (B, T, V_original) - return logits; + return modules_[kLMHeadLayerName]->Forward(x1); +} + +GPT2::GPT2(const GPT2Config &config) + : config_(config), stage_info_(nn::parallel::PipelineParallel::GetStageInfo( + config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, + nn::parallel::global::GetVirtualPipelineParallelSize())) { + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); + + // NOTE(zbl): VocabParallelEmbedding requires vocab_size % tp_size == 0 + // Megatron-LM has an optional argument `--make-vocab-size-divisible-by`, would do padding to vocab + // Here we introduce padding by default, might need modify Tokenizer correspondingly later + CHECK_EQ(config.vocab_size % tp_world_size, 0) << "Vocab size should be divisible by TP world size"; + + std::unordered_map> transformer; + if (stage_info_.is_first_stage) { + modules_[kPPFirstStageName] = std::make_shared(config_); + transformer[GPT2FirstStage::kWTELayerName] + = modules_[kPPFirstStageName]->mutable_module(GPT2FirstStage::kWTELayerName); + transformer[GPT2FirstStage::kWPELayerName] + = 
modules_[kPPFirstStageName]->mutable_module(GPT2FirstStage::kWPELayerName); + } + + { + std::map>> start_layer_to_layer_size_and_chunk; + for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { + const auto [start_layer, end_layer] = stage_info_.layer_ranges_per_chunk[chunk_idx]; + auto chunk = std::make_shared(config_, start_layer, end_layer); + start_layer_to_layer_size_and_chunk[start_layer] = std::make_pair(end_layer - start_layer, chunk); + } + std::vector> h; + int chunk_idx = 0; + for (auto &[start_layer, layer_size_and_chunk] : start_layer_to_layer_size_and_chunk) { + auto [layer_size, chunk] = layer_size_and_chunk; + for (int idx = 0; idx < layer_size; ++idx) { + h.push_back(chunk->mutable_module(GPT2Chunk::kHLayerName)->mutable_module(std::to_string(idx))); + } + modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)] = std::move(chunk); + ++chunk_idx; + } + transformer[GPT2Chunk::kHLayerName] = std::make_shared(std::move(h)); + } + + if (stage_info_.is_last_stage) { + modules_[kPPLastStageName] = std::make_shared(config_); + transformer[GPT2LastStage::kLnFLayerName] + = modules_[kPPLastStageName]->mutable_module(GPT2LastStage::kLnFLayerName); + modules_[GPT2LastStage::kLMHeadLayerName] + = modules_[kPPLastStageName]->mutable_module(GPT2LastStage::kLMHeadLayerName); + } + modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); + + // FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real tying operation + if (nn::parallel::global::GetPipelineParallelSize() == 1) { + // https://paperswithcode.com/method/weight-tying + *mutable_module(kTransformerLayerName) + ->mutable_module(GPT2FirstStage::kWTELayerName) + ->mutable_parameter(nn::parallel::VocabParallelEmbedding::kParamWeightName) + = module(GPT2LastStage::kLMHeadLayerName).parameter(nn::parallel::ColumnParallelLinear::kParamWeightName); + } +} + +std::vector> +GPT2::Forward(const std::vector> &x) { + auto x1 = modules_[kPPFirstStageName]->Forward(x); + for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { + x1 = modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)]->Forward(x1); + } + return modules_[kPPLastStageName]->Forward(x1); } std::shared_ptr GPT2::FromPretrained(ModelType model_type) { @@ -479,12 +423,12 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { // local: (vocab_size_per_partition, n_embd) if (is_first_stage) { auto &transformer_wte_weight - = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kWTELayerName, + = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2FirstStage::kWTELayerName, nn::parallel::VocabParallelEmbedding::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd, v_start, vpp); } else if (pp_size > 1 && is_last_stage) { - auto &lm_head_weight = state_dict[std::format("{}.{}", GPT2::kLMHeadLayerName, + auto &lm_head_weight = state_dict[std::format("{}.{}", GPT2LastStage::kLMHeadLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(lm_head_weight->DataPtr()), model_vocab_size, n_embd, v_start, vpp); @@ -500,8 +444,8 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { if (is_first_stage) { // transformer.wpe.weight - auto &transformer_wpe_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, - GPT2::kWPELayerName, nn::Embedding::kParamWeightName)]; + auto 
&transformer_wpe_weight = state_dict[std::format( + "{}.{}.{}", GPT2::kTransformerLayerName, GPT2FirstStage::kWPELayerName, nn::Embedding::kParamWeightName)]; ReadMatrixAllFloat(ifs, static_cast(transformer_wpe_weight->DataPtr()), block_size, n_embd); } else { size_t wpe_bytes = block_size * n_embd * sizeof(float); @@ -512,7 +456,7 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { int local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, std::to_string(local_layer_index), Block::kLn1LayerName, nn::LayerNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); @@ -527,7 +471,7 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, std::to_string(local_layer_index), Block::kLn1LayerName, nn::LayerNorm::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); @@ -542,9 +486,9 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(local_layer_index), Block::kAttnLayerName, - CausalSelfAttention::kCAttnLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + GPT2Chunk::kHLayerName, std::to_string(local_layer_index), + Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim, // i.e. [Q|K|V].T = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn].T @@ -585,9 +529,9 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(local_layer_index), Block::kAttnLayerName, - CausalSelfAttention::kCAttnLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + GPT2Chunk::kHLayerName, std::to_string(local_layer_index), + Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)]; // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated // i.e. 
[Q|K|V] = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn] @@ -627,9 +571,9 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(local_layer_index), Block::kAttnLayerName, - CausalSelfAttention::kCProjLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + GPT2Chunk::kHLayerName, std::to_string(local_layer_index), + Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp, in_pp); @@ -644,9 +588,9 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(local_layer_index), Block::kAttnLayerName, - CausalSelfAttention::kCProjLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + GPT2Chunk::kHLayerName, std::to_string(local_layer_index), + Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); ++local_layer_index; @@ -660,7 +604,7 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, std::to_string(local_layer_index), Block::kLn2LayerName, nn::LayerNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); @@ -675,7 +619,7 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, std::to_string(local_layer_index), Block::kLn2LayerName, nn::LayerNorm::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); @@ -690,9 +634,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(local_layer_index), - Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFcLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp); ++local_layer_index; } else { @@ -705,9 +650,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < 
n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(local_layer_index), - Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)]; + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFcLayerName, + nn::parallel::ColumnParallelLinear::kParamBiasName)]; ReadVectorShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, fc_start, fc_pp); ++local_layer_index; } else { @@ -720,9 +666,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(local_layer_index), - Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp, in4_pp); ++local_layer_index; @@ -736,9 +683,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { if (owned_layers[idx]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(local_layer_index), - Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)]; + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); ++local_layer_index; } else { @@ -749,12 +697,12 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { if (is_last_stage) { // transformer.ln_f.weight - auto &transformer_ln_f_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, - GPT2::kLnFLayerName, nn::LayerNorm::kParamWeightName)]; + auto &transformer_ln_f_weight = state_dict[std::format( + "{}.{}.{}", GPT2::kTransformerLayerName, GPT2LastStage::kLnFLayerName, nn::LayerNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_weight->DataPtr()), n_embd); // transformer.ln_f.bias - auto &transformer_ln_f_bias = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, - GPT2::kLnFLayerName, nn::LayerNorm::kParamBiasName)]; + auto &transformer_ln_f_bias = state_dict[std::format( + "{}.{}.{}", GPT2::kTransformerLayerName, GPT2LastStage::kLnFLayerName, nn::LayerNorm::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_bias->DataPtr()), n_embd); } else { size_t ln_f_w_bytes = n_embd * sizeof(float); @@ -763,3 +711,5 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } return local_gpt2; } + +int GPT2::GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); } diff --git a/example/gpt2/net.h b/example/gpt2/net.h index 24e4eae..f52e4d4 100644 --- a/example/gpt2/net.h +++ b/example/gpt2/net.h @@ -7,8 
+7,9 @@ #include "glog/logging.h" -#include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" +#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" #include "infini_train/include/tensor.h" struct GPT2Config { @@ -71,15 +72,51 @@ class Block : public infini_train::nn::CloneableModule { Forward(const std::vector> &x) override; }; -class GPT2 : public infini_train::nn::CloneableModule { +class GPT2FirstStage : public infini_train::nn::CloneableModule { public: static constexpr char kWTELayerName[] = "wte"; static constexpr char kWPELayerName[] = "wpe"; + + explicit GPT2FirstStage(const GPT2Config &config); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const GPT2Config config_; +}; + +class GPT2Chunk : public infini_train::nn::CloneableModule { +public: static constexpr char kHLayerName[] = "h"; + + GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const GPT2Config config_; +}; + +class GPT2LastStage : public infini_train::nn::CloneableModule { +public: static constexpr char kLnFLayerName[] = "ln_f"; - static constexpr char kTransformerLayerName[] = "transformer"; static constexpr char kLMHeadLayerName[] = "lm_head"; + explicit GPT2LastStage(const GPT2Config &config); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const GPT2Config config_; +}; + +class GPT2 : public infini_train::nn::CloneableModule { +public: + static constexpr char kTransformerLayerName[] = "transformer"; + enum class ModelType : int8_t { kGPT2, kGPT2Medium, @@ -92,31 +129,12 @@ class GPT2 : public infini_train::nn::CloneableModule { std::vector> Forward(const std::vector> &x) override; - std::vector> BuildChunks(int pp_rank) override; - static std::shared_ptr FromPretrained(ModelType model_type); static std::shared_ptr FromLLMC(const std::string &filepath); -private: - GPT2Config config_; -}; - -class GPT2Chunk : public infini_train::nn::CloneableModule { -public: - GPT2Chunk(GPT2 *parent, int layer_begin, int chunk_layers, bool has_embedding, bool has_lm_head, - const GPT2Config &config) - : parent_(parent), layer_begin_(layer_begin), chunk_layers_(chunk_layers), has_embedding_(has_embedding), - has_lm_head_(has_lm_head), config_(config) {} - - std::vector> - Forward(const std::vector> &x) override; + int GetChunkSize() const; private: - GPT2 *parent_ = nullptr; - int layer_begin_ = 0; - int chunk_layers_ = 0; - bool has_embedding_ = false; - bool has_lm_head_ = false; - - GPT2Config config_; + const GPT2Config config_; + const infini_train::nn::parallel::StageInfo stage_info_; }; diff --git a/example/llama3/main.cc b/example/llama3/main.cc index e4a234e..fdea216 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -168,16 +168,35 @@ void Train(const nn::parallel::Rank &rank) { LOG(FATAL) << "Rank " << rank.GlobalRank() << ": Datatype " << FLAGS_dtype << " not supported."; } - // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions - // before wrapping the model with DistributedDataParallel (DDP). - // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors - // are created during the conversion. 
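// Illustrative only: one common interleaved split behind a StageInfo-style layer_ranges_per_chunk,
// assuming total_layers % (pp_size * vpp_size) == 0. PipelineParallel::GetStageInfo is the source
// of truth for this repository and may partition the layers differently.
#include <utility>
#include <vector>

std::vector<std::pair<int, int>> LayerRangesPerChunk(int total_layers, int pp_size, int pp_rank, int vpp_size) {
    const int layers_per_chunk = total_layers / (pp_size * vpp_size);
    std::vector<std::pair<int, int>> ranges;
    for (int chunk = 0; chunk < vpp_size; ++chunk) {
        // Chunk `chunk` of stage `pp_rank` owns the (chunk * pp_size + pp_rank)-th slice of layers.
        const int start = (chunk * pp_size + pp_rank) * layers_per_chunk;
        ranges.emplace_back(start, start + layers_per_chunk);
    }
    return ranges;
}
// e.g. total_layers = 12, pp_size = 2, vpp_size = 2: rank 0 owns [0,3) and [6,9), rank 1 owns
// [3,6) and [9,12); GetChunkSize() would then report 2 chunks on every rank.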
- if (ddp_world_size > 1) { + auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); + + // TODO(dcj): support more complex optimizer later + auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate); + + if (pp_world_size > 1) { + // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct + // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. + auto shapes = std::vector>{ + {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; + + model = std::make_shared( + model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared(optimizer), + rank.thread_rank(), std::dynamic_pointer_cast(model)->GetChunkSize()); + if (ddp_world_size > 1) { + auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); + for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { + (*mutable_chunks)[chunk_id] + = std::make_shared(mutable_chunks->at(chunk_id), rank.thread_rank()); + } + } + } else if (ddp_world_size > 1) { + // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions + // before wrapping the model with DistributedDataParallel (DDP). + // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors + // are created during the conversion. model = std::make_shared(model, rank.thread_rank()); } - auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); - DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size, ddp_rank, ddp_world_size); @@ -197,9 +216,6 @@ void Train(const nn::parallel::Rank &rank) { tokenizer = std::make_unique(FLAGS_tokenizer_bin); } - // TODO(dcj): support more complex optimizer later - auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate); - auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast(std::make_shared()) @@ -207,17 +223,6 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training"; - if (pp_world_size > 1) { - // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct - // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. 
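// Worked example of the two expressions above, with hypothetical flag values chosen only to make
// the arithmetic concrete (they are not defaults of this example).
#include <cstdint>
#include <vector>

constexpr int64_t kTotalBatchSize = 524288;  // tokens consumed per optimizer step
constexpr int64_t kBatchSize = 4;            // sequences per micro-batch
constexpr int64_t kSequenceLength = 1024;    // tokens per sequence
constexpr int64_t kDdpWorldSize = 2;
constexpr int64_t kSpWorldSize = 2;
constexpr int64_t kNEmbd = 768;

// 524288 / (4 * 1024 * 2) = 64 micro-batches per data-parallel replica per optimizer step.
constexpr int64_t kNumMicroBatches = kTotalBatchSize / (kBatchSize * kSequenceLength * kDdpWorldSize);

// With sequence parallelism the activation crossing a stage boundary is already sharded along the
// sequence dimension, hence the division by sp_world_size: {4, 1024 / 2, 768} = {4, 512, 768}.
const std::vector<std::vector<int64_t>> kRecvShapes = {{kBatchSize, kSequenceLength / kSpWorldSize, kNEmbd}};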
- auto shapes = std::vector>{ - {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; - - model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, - pp_rank, std::make_shared(optimizer), - rank.thread_rank()); - } - for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { const bool last_step = step == FLAGS_num_iteration; diff --git a/example/llama3/net.cc b/example/llama3/net.cc index 67c9548..a70a811 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -24,7 +25,6 @@ #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" -#include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/tensor.h" using namespace infini_train; @@ -325,137 +325,27 @@ std::vector> Block::Forward(const std::vector> transformer; - if (is_first_stage) { - transformer[kWTELayerName] = std::make_shared( - config.vocab_size, config.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); - } - - std::vector> h_local; - for (const auto &[start_layer, end_layer] : layer_ranges_per_chunk) { - for (int64_t i = start_layer; i < end_layer; ++i) { h_local.push_back(std::make_shared(config)); } - } - transformer[kHLayerName] = std::make_shared(std::move(h_local)); - - if (is_last_stage) { - transformer[kLnFLayerName] = std::make_shared(config.n_embd, config.norm_eps); - // NOTE(zbl): weight-tying is possible but torch script did not do so - modules_[kLMHeadLayerName] = std::make_shared( - /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, - /*bias=*/false, - // NOTE(zbl): each rank would get sharded [B, T, V_local] as logits - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - } - modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); +LLaMA3FirstStage::LLaMA3FirstStage(const LLaMA3Config &config) : config_(config) { + modules_[LLaMA3FirstStage::kWTELayerName] = std::make_shared( + config.vocab_size, config.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); } -std::vector> LLaMA3::BuildChunks(int pp_rank) { - int pp_size = nn::parallel::global::GetPipelineParallelSize(); - int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); - - auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] - = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, pp_rank, vpp_size); - - std::vector> chunks; - chunks.reserve(layer_ranges_per_chunk.size()); - - int stage_layer_off = 0; - for (size_t i = 0; i < layer_ranges_per_chunk.size(); ++i) { - auto [start, end] = layer_ranges_per_chunk[i]; - - int chunk_layers = end - start; - chunks.emplace_back(std::make_shared( - this, - stage_layer_off, // Starting offset for layer indexing - chunk_layers, // Number of layers in this chunk - (i == 0 && is_first_stage), // Whether to include embedding - (i == layer_ranges_per_chunk.size() - 1 && is_last_stage), // Whether to include lm_head - config_)); - stage_layer_off += chunk_layers; - } - - return chunks; +std::vector> LLaMA3FirstStage::Forward(const std::vector> &x) { + return modules_[LLaMA3FirstStage::kWTELayerName]->Forward(x); } -std::vector> LLaMA3Chunk::Forward(const std::vector> &input) { - auto transformer = 
parent_->mutable_module(LLaMA3::kTransformerLayerName); - auto x1 = input[0]; - const auto device = x1->GetDevice(); - - int pp_size = nn::parallel::global::GetPipelineParallelSize(); - int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); - auto pp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( - nn::parallel::GetPipelineParallelProcessGroupName(device->rank().GlobalRank())); - auto pp_rank = pp_group->GetGroupRank(device->rank().GlobalRank()); - - auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] - = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, pp_rank, vpp_size); - - bool is_pipeline_first_chunk = is_first_stage && (layer_ranges_per_chunk[0].first == 0); - const auto t = x1->Dims()[1] - * (is_pipeline_first_chunk ? 1 : nn::parallel::global::GetSequenceParallelSize()); // full_seq_len - - if (has_embedding_) { - x1 = transformer->mutable_module(LLaMA3::kWTELayerName)->Forward({x1})[0]; - } - - // Init freqs_cis on device only once - // TODO(zbl): consider moving this to model construction - if (buffers_[LLaMA3::kFreqsCisName] == nullptr) { - buffers_[LLaMA3::kFreqsCisName] = PrecomputeFreqsCis(config_.n_embd / config_.n_head, config_.block_size * 2, - config_.rope_theta, config_.use_scaled_rope, device); +LLaMA3Chunk::LLaMA3Chunk(const LLaMA3Config &config, int start_layer, int end_layer) : config_(config) { + std::vector> h; + for (int64_t i = start_layer; i < end_layer; ++i) { + auto layer = std::make_shared(config); + h.push_back(layer); } - - // TODO(zbl): dynamic start_pos - int64_t start_pos = 0; - auto freqs_view = buffers_[LLaMA3::kFreqsCisName]->Slice(0, start_pos, start_pos + t, 1); - - // TODO(lzm): add dtype support for nn::function::Ones later - std::shared_ptr ones = std::make_shared(nn::function::Ones({t, t})->To(x1->GetDevice())); - std::shared_ptr mask = nn::function::Triu(ones, 1)->View({1, 1, t, t}); - - std::shared_ptr start_pos_ptr = nullptr; - - auto h_modules = transformer->mutable_module(LLaMA3::kHLayerName); - CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; - auto blocks = std::dynamic_pointer_cast(h_modules); - // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (int i = 0; i < chunk_layers_; ++i) { - x1 = (*blocks)[layer_begin_ + i]->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; - } - - if (has_lm_head_) { - // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) - auto x2 = transformer->mutable_module(LLaMA3::kLnFLayerName)->Forward({x1}); - - // TODO(zbl): add inference-time mini-optimization - // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) - auto logits = parent_->mutable_module(LLaMA3::kLMHeadLayerName)->Forward(x2); - - return logits; - } - - return {x1}; + modules_[LLaMA3Chunk::kHLayerName] = std::make_shared(std::move(h)); } -std::vector> LLaMA3::Forward(const std::vector> &x) { - // (bs, seq_len) +std::vector> LLaMA3Chunk::Forward(const std::vector> &x) { auto x1 = x[0]; const auto device = x1->GetDevice(); - const auto t = x1->Dims()[1]; // full_seq_len - CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " - << config_.block_size; - // Init freqs_cis on device only once // TODO(zbl): consider moving this to model construction if (buffers_[kFreqsCisName] == nullptr) { @@ -463,11 +353,8 @@ std::vector> LLaMA3::Forward(const std::vector Embedding(vocab_size, n_embd) -> (bs, seq_len, n_embd) - x1 = 
transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; + // TODO(dcj): check if this shape is correct + const auto t = x1->Dims()[1] * nn::parallel::global::GetSequenceParallelSize(); // full_seq_len // TODO(zbl): dynamic start_pos int64_t start_pos = 0; @@ -479,21 +366,84 @@ std::vector> LLaMA3::Forward(const std::vector start_pos_ptr = nullptr; - auto h_modules = transformer->mutable_module(kHLayerName); - CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; - auto h_layers = std::dynamic_pointer_cast(h_modules); // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (auto &h : *h_layers) { x1 = h->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; } + for (auto &h : *std::dynamic_pointer_cast(modules_[LLaMA3Chunk::kHLayerName])) { + x1 = h->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; + } + return {x1}; +} + +LLaMA3LastStage::LLaMA3LastStage(const LLaMA3Config &config) : config_(config) { + modules_[kLnFLayerName] = std::make_shared(config.n_embd, config.norm_eps); + // NOTE(zbl): weight-tying is possible but torch script did not do so + modules_[kLMHeadLayerName] = std::make_shared( + /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, + /*bias=*/false, + // NOTE(zbl): each rank would get sharded [B, T, V_local] as logits + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); +} +std::vector> LLaMA3LastStage::Forward(const std::vector> &x) { // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) - auto x2 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); + auto x1 = modules_[kLnFLayerName]->Forward(x); // TODO(zbl): add inference-time mini-optimization // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) - auto logits = modules_[kLMHeadLayerName]->Forward(x2); + return modules_[kLMHeadLayerName]->Forward(x1); +} + +LLaMA3::LLaMA3(const LLaMA3Config &config) + : config_(config), stage_info_(nn::parallel::PipelineParallel::GetStageInfo( + config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, + nn::parallel::global::GetVirtualPipelineParallelSize())) { + std::unordered_map> transformer; + if (stage_info_.is_first_stage) { + modules_[kPPFirstStageName] = std::make_shared(config_); + transformer[LLaMA3FirstStage::LLaMA3FirstStage::kWTELayerName] + = modules_[kPPFirstStageName]->mutable_module(LLaMA3FirstStage::LLaMA3FirstStage::kWTELayerName); + } + + { + std::map>> start_layer_to_layer_size_and_chunk; + for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { + const auto [start_layer, end_layer] = stage_info_.layer_ranges_per_chunk[chunk_idx]; + auto chunk = std::make_shared(config_, start_layer, end_layer); + start_layer_to_layer_size_and_chunk[start_layer] = std::make_pair(end_layer - start_layer, chunk); + } + std::vector> h; + int chunk_idx = 0; + for (auto &[start_layer, layer_size_and_chunk] : start_layer_to_layer_size_and_chunk) { + auto [layer_size, chunk] = layer_size_and_chunk; + for (int idx = 0; idx < layer_size; ++idx) { + h.push_back( + chunk->mutable_module(LLaMA3Chunk::LLaMA3Chunk::kHLayerName)->mutable_module(std::to_string(idx))); + } + modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)] = std::move(chunk); + ++chunk_idx; + } + transformer[LLaMA3Chunk::LLaMA3Chunk::kHLayerName] = std::make_shared(std::move(h)); + } + + if (stage_info_.is_last_stage) { + 
modules_[kPPLastStageName] = std::make_shared(config_); + transformer[LLaMA3LastStage::kLnFLayerName] + = modules_[kPPLastStageName]->mutable_module(LLaMA3LastStage::kLnFLayerName); + // NOTE(zbl): weight-tying is possible but torch script did not do so + modules_[LLaMA3LastStage::kLMHeadLayerName] + = modules_[kPPLastStageName]->mutable_module(LLaMA3LastStage::kLMHeadLayerName); + } + modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); +} - // (bs, seq_len, vocab_size) - return logits; +std::vector> LLaMA3::Forward(const std::vector> &x) { + auto x1 = modules_[kPPFirstStageName]->Forward({x[0]}); + for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { + x1 = modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)]->Forward(x1); + } + return modules_[kPPLastStageName]->Forward(x1); } std::shared_ptr LLaMA3::FromPretrained(ModelType model_type) { @@ -634,7 +584,7 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // ========== Read Sharded Params ========== // transformer.wte.weight : (vocab_size, n_embd) -> local tp_rank: rows of [v_start : v_start+vpp) if (is_first_stage) { - auto &wte = state_dict[std::format("{}.{}.{}", kTransformerLayerName, kWTELayerName, + auto &wte = state_dict[std::format("{}.{}.{}", kTransformerLayerName, LLaMA3FirstStage::kWTELayerName, nn::parallel::VocabParallelEmbedding::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(wte->DataPtr()), /*rows=*/vocab_size, /*cols=*/n_embd, @@ -648,7 +598,7 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { int local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), Block::kLn1LayerName, RMSNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); @@ -664,7 +614,7 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; @@ -704,7 +654,7 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; @@ -721,9 +671,8 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // transformer.h.{i}.ln_2.weight : Full version RMSNorm local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, 
LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), Block::kLn2LayerName, RMSNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); @@ -737,10 +686,9 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // transformer.h.{i}.mlp.c_fc.weight : ColumnParallelLinear, but actually applies on "rows" local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - if (owned_layers[i]) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(local_layer_index), + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/fc_out, /*cols=*/n_embd, @@ -755,10 +703,9 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // transformer.h.{i}.mlp.c_fc2.weight : ColumnParallelLinear, but actually applies on "rows" local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - if (owned_layers[i]) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(local_layer_index), + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFc2LayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/fc_out, /*cols=*/n_embd, @@ -773,10 +720,9 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // transformer.h.{i}.mlp.c_proj.weight : RowParallelLinear, but actually applies on "columns" local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - if (owned_layers[i]) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(local_layer_index), + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/n_embd, /*cols=*/fc_out, @@ -792,9 +738,9 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // lm_head.weight : (vocab_size, n_embd) -> ColumnParallelLinear, but actually applies on "rows" { if (is_last_stage) { - auto &ln_f - = state_dict[std::format("{}.{}.{}", kTransformerLayerName, kLnFLayerName, RMSNorm::kParamWeightName)]; - auto &lm_head = state_dict[std::format("{}.{}", kLMHeadLayerName, + auto &ln_f = state_dict[std::format("{}.{}.{}", kTransformerLayerName, LLaMA3LastStage::kLnFLayerName, + RMSNorm::kParamWeightName)]; + auto &lm_head = state_dict[std::format("{}.{}", LLaMA3LastStage::kLMHeadLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(ln_f->DataPtr()), n_embd); ReadMatrixRowShardFloat(ifs, static_cast(lm_head->DataPtr()), diff --git a/example/llama3/net.h b/example/llama3/net.h index 7bc5df6..9bd7f9d 100644 --- a/example/llama3/net.h +++ b/example/llama3/net.h @@ -10,6 +10,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/tensor.h" struct LLaMA3Config { @@ -108,15 +109,50 @@ class Block : public infini_train::nn::CloneableModule 
{ Forward(const std::vector> &x) override; }; -class LLaMA3 : public infini_train::nn::CloneableModule { +class LLaMA3FirstStage : public infini_train::nn::CloneableModule { public: static constexpr char kWTELayerName[] = "wte"; + + explicit LLaMA3FirstStage(const LLaMA3Config &config); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const LLaMA3Config config_; +}; + +class LLaMA3Chunk : public infini_train::nn::CloneableModule { +public: static constexpr char kHLayerName[] = "h"; + static constexpr char kFreqsCisName[] = "freqs_cis"; + + LLaMA3Chunk(const LLaMA3Config &config, int start_layer, int end_layer); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const LLaMA3Config config_; +}; + +class LLaMA3LastStage : public infini_train::nn::CloneableModule { +public: static constexpr char kLnFLayerName[] = "ln_f"; - static constexpr char kTransformerLayerName[] = "transformer"; static constexpr char kLMHeadLayerName[] = "lm_head"; - static constexpr char kFreqsCisName[] = "freqs_cis"; + explicit LLaMA3LastStage(const LLaMA3Config &config); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const LLaMA3Config config_; +}; + +class LLaMA3 : public infini_train::nn::CloneableModule { +public: + static constexpr char kTransformerLayerName[] = "transformer"; enum class ModelType : int8_t { // TODO(zbl): more model type from huggingface @@ -132,31 +168,12 @@ class LLaMA3 : public infini_train::nn::CloneableModule { std::vector> Forward(const std::vector> &x) override; - std::vector> BuildChunks(int pp_rank) override; - static std::shared_ptr FromPretrained(ModelType model_type); static std::shared_ptr FromLLMC(const std::string &filepath); -private: - LLaMA3Config config_; -}; - -class LLaMA3Chunk : public infini_train::nn::CloneableModule { -public: - LLaMA3Chunk(LLaMA3 *parent, int layer_begin, int chunk_layers, bool has_embedding, bool has_lm_head, - const LLaMA3Config &config) - : parent_(parent), layer_begin_(layer_begin), chunk_layers_(chunk_layers), has_embedding_(has_embedding), - has_lm_head_(has_lm_head), config_(config){}; - - std::vector> - Forward(const std::vector> &x) override; + int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); } private: - LLaMA3 *parent_ = nullptr; - int layer_begin_ = 0; - int chunk_layers_ = 0; - bool has_embedding_ = false; - bool has_lm_head_ = false; - - LLaMA3Config config_; + const LLaMA3Config config_; + const infini_train::nn::parallel::StageInfo stage_info_; }; diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 2223851..9bc78bc 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -25,6 +25,10 @@ class Module : public std::enable_shared_from_this { public: static constexpr char kUndefinedType[] = "Undefined"; + static constexpr char kPPFirstStageName[] = "__pp_first_stage"; + static constexpr char kPPLastStageName[] = "__pp_last_stage"; + static constexpr char kPPChunkNamePrefix[] = "__pp_chunk_"; + explicit Module(); explicit Module(const std::string &type); Module(const Module &) = default; @@ -54,8 +58,6 @@ class Module : public std::enable_shared_from_this { return 0.0f; }; - virtual std::vector> BuildChunks(int pp_rank); - virtual void To(const Device *device); virtual void To(DataType dtype); diff --git a/infini_train/include/nn/parallel/distributed_data_parallel.h b/infini_train/include/nn/parallel/distributed_data_parallel.h index 
62adff0..6001a17 100644 --- a/infini_train/include/nn/parallel/distributed_data_parallel.h +++ b/infini_train/include/nn/parallel/distributed_data_parallel.h @@ -19,8 +19,6 @@ class DistributedDataParallel : public nn::Module { std::vector> Forward(const std::vector> &input_tensors) override; - std::vector> BuildChunks(int pp_rank) override; - private: std::shared_ptr reducer_ = nullptr; }; diff --git a/infini_train/include/nn/parallel/pp/pipeline_parallel.h b/infini_train/include/nn/parallel/pp/pipeline_parallel.h index de89b01..2ddd091 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_parallel.h +++ b/infini_train/include/nn/parallel/pp/pipeline_parallel.h @@ -31,7 +31,7 @@ class PipelineParallel : public Module { public: PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, const std::vector> &recv_shape, int rank, - const std::shared_ptr &optimizer, int device_id); + const std::shared_ptr &optimizer, int device_id, int vpp); float TrainStep(const std::vector> &input, const std::vector> &target, const std::shared_ptr &loss_fn, @@ -39,16 +39,18 @@ class PipelineParallel : public Module { static StageInfo GetStageInfo(int total_layers, int pp_size, int pp_rank, int chunks_per_stage = 1); + std::vector> *mutable_chunks(); + private: + void BuildPipelineStage(const std::shared_ptr &optimizer, + const std::vector> &recv_shape, int device_id, + std::vector> &&chunks); + + void SetupSchedule(int num_micro_batches); + int num_stages_ = -1; int rank_ = -1; - std::shared_ptr pipeline_stage_ = nullptr; std::shared_ptr schedule_ = nullptr; - - void BuildPipelineStage(const std::shared_ptr &model, const std::shared_ptr &optimizer, - const std::vector> &recv_shape, int device_id); - - void SetupSchedule(int num_micro_batches); + std::shared_ptr pipeline_stage_ = nullptr; }; - } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h index 019a4af..80cd034 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_stage.h +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -16,22 +16,25 @@ namespace infini_train::nn::parallel { class PipelineStage { public: - PipelineStage(const std::shared_ptr &model, int stage_index, int num_stages, - const std::vector> &recv_shape, std::shared_ptr optimizer, - int device_id); + PipelineStage(int stage_index, int num_stages, const std::vector> &recv_shape, + std::shared_ptr optimizer, int device_id, std::vector> &&chunks); std::vector> ForwardOneChunk(const std::vector> &inputs, int local_chunk_idx = 0); bool IsFirstStage() const { return stage_index_ == 0; } bool IsLastStage() const { return stage_index_ == num_stages_ - 1; } + int stage_index() const { return stage_index_; } int prev_rank() const { return prev_rank_; } int next_rank() const { return next_rank_; } int num_stages() const { return num_stages_; } + const Device *device() const { return device_; } const std::vector> &recv_shape() const { return recv_shape_; } std::shared_ptr optimizer() { return optimizer_; } + const auto &chunks() { return chunks_; } + auto *mutable_chunks() { return &chunks_; } private: int stage_index_ = -1; @@ -39,7 +42,6 @@ class PipelineStage { int prev_rank_ = -1; int next_rank_ = -1; const Device *device_ = nullptr; - std::shared_ptr model_ = nullptr; std::vector> chunks_; std::shared_ptr optimizer_ = nullptr; std::vector> recv_shape_; diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc 
index e1684c1..4e0c6a2 100644
--- a/infini_train/src/nn/modules/module.cc
+++ b/infini_train/src/nn/modules/module.cc
@@ -20,10 +20,22 @@ const std::string &Module::type() const { return type_; }
 
 std::vector> Module::Parameters() const {
     std::vector> params;
-    for (auto &[_, param] : parameters_) { params.push_back(param); }
-    for (auto &[_, module] : modules_) {
-        for (auto &param : module->Parameters()) { params.push_back(param); }
+    std::unordered_set visited;
+
+    auto AddIfUnvisited = [&](const std::shared_ptr &param) {
+        if (visited.insert(param.get()).second) {
+            params.push_back(param);
+        }
+    };
+
+    // Add parameters of this module
+    for (const auto &[_, param] : parameters_) { AddIfUnvisited(param); }
+
+    // Recursively add parameters of submodules
+    for (const auto &[_, module] : modules_) {
+        for (const auto &param : module->Parameters()) { AddIfUnvisited(param); }
     }
+
     return params;
 }
 
@@ -100,6 +112,9 @@ std::unordered_map> Module::StateDict() con
     for (auto &[name, param] : parameters_) { state.emplace(name, param); }
     for (auto &[name, buffer] : buffers_) { state.emplace(name, buffer); }
     for (auto &[name, module] : modules_) {
+        if (name.starts_with("__pp")) {
+            continue;
+        }
         for (auto &[sub_name, param] : module->StateDict()) { state.emplace(name + "." + sub_name, param); }
     }
     return state;
@@ -110,11 +125,6 @@ std::vector> Module::Forward(const std::vector
-std::vector> Module::BuildChunks(int pp_rank) {
-    LOG(FATAL) << "BuildChunks function not implemented for this module";
-    return {};
-}
-
 void Module::To(const Device *device) {
     CHECK_NOTNULL(device);
     if (device == device_) {
diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc
index f81565f..a25a7d1 100644
--- a/infini_train/src/nn/parallel/distributed_data_parallel.cc
+++ b/infini_train/src/nn/parallel/distributed_data_parallel.cc
@@ -56,8 +56,4 @@ DistributedDataParallel::Forward(const std::vector> &inp
     }
     return outputs;
 }
-
-std::vector> DistributedDataParallel::BuildChunks(int pp_rank) {
-    return modules_[kModuleName]->BuildChunks(pp_rank);
-}
 } // namespace infini_train::nn::parallel
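The pointer-identity deduplication added to Module::Parameters() matters whenever one tensor is reachable through more than one module (weight tying, or the "__pp_*" aliases registered alongside the regular submodule names), because a duplicate entry would make the optimizer step the shared parameter twice per iteration. A small self-contained illustration of the visited-set pattern used above, with a stand-in Tensor type rather than infini_train's:

    #include <iostream>
    #include <memory>
    #include <unordered_set>
    #include <vector>

    // Stand-in for a parameter tensor; only pointer identity matters for the dedup.
    struct Tensor {};

    int main() {
        auto shared_weight = std::make_shared<Tensor>(); // e.g. a tied embedding / lm_head weight
        auto other_param = std::make_shared<Tensor>();

        // Two "modules" that both report the shared weight.
        std::vector<std::shared_ptr<Tensor>> module_a{shared_weight, other_param};
        std::vector<std::shared_ptr<Tensor>> module_b{shared_weight};

        std::vector<std::shared_ptr<Tensor>> params;
        std::unordered_set<const Tensor *> visited;
        for (const auto &mod : {module_a, module_b}) {
            for (const auto &p : mod) {
                if (visited.insert(p.get()).second) { // true only on the first sighting
                    params.push_back(p);
                }
            }
        }
        std::cout << params.size() << '\n'; // prints 2, not 3
    }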
diff --git a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc
index 77a9fa3..bbdceae 100644
--- a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc
+++ b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc
@@ -3,6 +3,7 @@
 #include
 #include
+#include
 
 #include "infini_train/include/nn/modules/container.h"
 #include "infini_train/include/nn/modules/module.h"
@@ -17,10 +18,11 @@ constexpr char kModuleName[] = "module";
 
 thread_local int pp_rank = 0;
 
-void PipelineParallel::BuildPipelineStage(const std::shared_ptr &module,
-                                          const std::shared_ptr &optimizer,
-                                          const std::vector> &recv_shape, int device_id) {
-    pipeline_stage_ = std::make_shared(module, rank_, num_stages_, recv_shape, optimizer, device_id);
+void PipelineParallel::BuildPipelineStage(const std::shared_ptr &optimizer,
+                                          const std::vector> &recv_shape, int device_id,
+                                          std::vector> &&chunks) {
+    pipeline_stage_
+        = std::make_shared(rank_, num_stages_, recv_shape, optimizer, device_id, std::move(chunks));
 }
 
 void PipelineParallel::SetupSchedule(int num_micro_batches) {
@@ -78,13 +80,30 @@ StageInfo PipelineParallel::GetStageInfo(int total_layers, int pp_size, int rank
 
 PipelineParallel::PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches,
                                    const std::vector> &recv_shape, int pp_rank,
-                                   const std::shared_ptr &optimizer, int device_id)
+                                   const std::shared_ptr &optimizer, int device_id, int chunk_size)
     : num_stages_(num_stages), rank_(pp_rank) {
     modules_[kModuleName] = std::move(module);
-    BuildPipelineStage(module, optimizer, recv_shape, device_id);
+    int stage_id = pp_rank;
+    int stage_size = num_stages;
+
+    std::vector> chunks;
+    for (int chunk_id = 0; chunk_id < chunk_size; ++chunk_id) {
+        std::vector> chunk_parts;
+        if (chunk_id == 0 && stage_id == 0) {
+            chunk_parts.push_back(module->mutable_module(kPPFirstStageName));
+        }
+        chunk_parts.push_back(module->mutable_module(kPPChunkNamePrefix + std::to_string(chunk_id)));
+        if (chunk_id == chunk_size - 1 && stage_id == stage_size - 1) {
+            chunk_parts.push_back(module->mutable_module(kPPLastStageName));
+        }
+        chunks.push_back(std::make_shared(std::move(chunk_parts)));
+    }
+
+    BuildPipelineStage(optimizer, recv_shape, device_id, std::move(chunks));
     SetupSchedule(num_micro_batches);
 }
+
+std::vector> *PipelineParallel::mutable_chunks() { return pipeline_stage_->mutable_chunks(); }
 } // namespace infini_train::nn::parallel
diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc
index fcec34b..3b77fab 100644
--- a/infini_train/src/nn/parallel/pp/pipeline_stage.cc
+++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc
@@ -10,17 +10,14 @@
 
 namespace infini_train::nn::parallel {
 
-PipelineStage::PipelineStage(const std::shared_ptr &model, int stage_index /* pp_rank */,
-                             int num_stages /* pp_size */, const std::vector> &recv_shape,
-                             std::shared_ptr optimizer, int device_id)
-    : model_(model), stage_index_(stage_index), num_stages_(num_stages),
-      prev_rank_(stage_index > 0 ? stage_index - 1 : -1),
+PipelineStage::PipelineStage(int stage_index /* pp_rank */, int num_stages /* pp_size */,
+                             const std::vector> &recv_shape, std::shared_ptr optimizer,
+                             int device_id, std::vector> &&chunks)
+    : stage_index_(stage_index), num_stages_(num_stages), prev_rank_(stage_index > 0 ? stage_index - 1 : -1),
       next_rank_(stage_index < num_stages - 1 ? stage_index + 1 : -1), recv_shape_(recv_shape),
      optimizer_(std::move(optimizer)),
-      device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(device_id)) {
-
-    chunks_ = model->BuildChunks(stage_index);
-}
+      device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(device_id)),
+      chunks_(std::move(chunks)) {}
 
 std::vector>
 PipelineStage::ForwardOneChunk(const std::vector> &inputs, int local_chunk_idx) {
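The chunk-grouping rule in the new PipelineParallel constructor can be read in isolation: each stage walks its locally numbered "__pp_chunk_i" modules, prepends the first-stage piece only for chunk 0 on stage 0, and appends the last-stage piece only for the final chunk on the final stage. A standalone sketch of that rule, with strings standing in for module pointers; the container type that wraps each chunk in the patch is not visible in this hunk and is omitted here:

    #include <iostream>
    #include <string>
    #include <vector>

    // Sketch only: which "__pp_*" submodules land in which chunk on a given stage.
    // Mirrors the grouping loop in PipelineParallel's constructor.
    std::vector<std::vector<std::string>> GroupChunks(int stage_id, int stage_size, int chunk_size) {
        std::vector<std::vector<std::string>> chunks;
        for (int chunk_id = 0; chunk_id < chunk_size; ++chunk_id) {
            std::vector<std::string> parts;
            if (chunk_id == 0 && stage_id == 0) {
                parts.push_back("__pp_first_stage");
            }
            parts.push_back("__pp_chunk_" + std::to_string(chunk_id));
            if (chunk_id == chunk_size - 1 && stage_id == stage_size - 1) {
                parts.push_back("__pp_last_stage");
            }
            chunks.push_back(std::move(parts));
        }
        return chunks;
    }

    int main() {
        // Two pipeline stages, two chunks per stage (interleaved / virtual pipeline).
        for (int stage = 0; stage < 2; ++stage) {
            for (const auto &chunk : GroupChunks(stage, /*stage_size=*/2, /*chunk_size=*/2)) {
                std::cout << "stage " << stage << ": [ ";
                for (const auto &name : chunk) { std::cout << name << ' '; }
                std::cout << "]\n";
            }
        }
        // Prints:
        //   stage 0: [ __pp_first_stage __pp_chunk_0 ]
        //   stage 0: [ __pp_chunk_1 ]
        //   stage 1: [ __pp_chunk_0 ]
        //   stage 1: [ __pp_chunk_1 __pp_last_stage ]
    }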