Skip to content
This repository was archived by the owner on Jan 10, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 67 additions & 48 deletions paragraph/translation/allgather/mesh_2d_allgather_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ Mesh2dAllGatherTranslator::Mesh2dAllGatherTranslator(
if (config.find("concentration") != config.end()) {
concentration_ = config["concentration"].get<uint64_t>();
}
integrated_local_exchange_ = false;
if (config.find("integrated_local_exchange") != config.end()) {
integrated_local_exchange_ =
config["integrated_local_exchange"].get<bool>();
}

// Create json config for internal 1D Mesh all-gather
nlohmann::json implicit_config = R"(
Expand Down Expand Up @@ -76,63 +81,77 @@ shim::StatusOr<std::unique_ptr<Subroutine>>
absl::InvalidArgumentError) <<
"Processor index points to the wrong Processor ID.";
Instruction* previous_instruction = nullptr;
std::vector<uint64_t> processor_coordinates;
std::unordered_set<int64_t> whole_world(comm_group.begin(), comm_group.end());
// Check if we have non-trivial concentration first
if (concentration_ > 1) {
processor_coordinates = ConsecutiveProcessorIdToGridCoordinates(
processor_id, dimension_sizes_, concentration_);
CommunicationGroup comm_group_conc;
for (uint64_t i = 0; i < concentration_; i++) {
processor_coordinates.at(0) = i;
uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId(
processor_coordinates, dimension_sizes_, concentration_);
if (whole_world.find(new_processor_id) != whole_world.end()) {
comm_group_conc.push_back(new_processor_id);
CommunicationGroup local_comm_group = CommunicationGroupLocalProjection(
processor_id, comm_group, dimension_sizes_, concentration_);
std::vector<double> stage_comm_sizes;
// We prepare communication sizes for each stage and each dimension, as
// dimensions and/or communication groups could be uneven
for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) {
if ((comm_size == 0) || (comm_group.empty())) {
stage_comm_sizes.push_back(0);
} else {
stage_comm_sizes.push_back(
comm_size / comm_group.size() / dimension_sizes_.size());
if (integrated_local_exchange_) {
// Additionally accommodate the split of traffic from the concentrator
// among dimensions
stage_comm_sizes.at(dim) /= dimension_sizes_.size();
}
}
if (comm_group_conc.size() > 1) {
ASSIGN_OR_RETURN(auto allgather_conc, Instruction::Create(
Opcode::kAllGather,
absl::StrCat(name_prefix,
"_dim-conc"),
allgather_sub_ptr));
allgather_conc->AppendCommunicationGroup(comm_group_conc);
allgather_conc->SetBytesOut(comm_size * concentration_ /
comm_group.size());
RETURN_IF_ERROR(allgather_translator_->Translate(allgather_conc));
previous_instruction = allgather_conc;
}
}
// Now do the same for every dimension of the mesh
for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) {
processor_coordinates = ConsecutiveProcessorIdToGridCoordinates(
processor_id, dimension_sizes_, concentration_);
CommunicationGroup comm_group_mesh;
uint64_t dim_width = dimension_sizes_.at(dim);
for (uint64_t i = 0; i < dim_width; i++) {
processor_coordinates.at(dim + 1) = i;
uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId(
processor_coordinates, dimension_sizes_, concentration_);
if (whole_world.find(new_processor_id) != whole_world.end()) {
comm_group_mesh.push_back(new_processor_id);
// We have as many stages as dimensions in the Mesh
for (size_t stage = 0; stage < dimension_sizes_.size(); stage++) {
// We run AllGather in parallel for every dimension of Mesh
std::vector<Instruction*> parallel_allgather;
for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) {
auto new_comm_group = CommunicationGroupProjectionOnGrid(
processor_id, comm_group, dim, integrated_local_exchange_,
dimension_sizes_, concentration_);
// With every new stage we should increase the communication size.
// On the first stage we only exchange data lying in the 1st dimension.
// On the second stage we exchange data from both 1st and 2nd dimensions.
stage_comm_sizes.at(dim) *= new_comm_group.size();
// If we don't have any communication in original comm_group within the
// current dimension, just leave it
if (new_comm_group.size() > 1) {
ASSIGN_OR_RETURN(auto allgather_stage, Instruction::Create(
Opcode::kAllGather,
absl::StrCat(name_prefix, "_stage-", stage, "_dim-", dim),
allgather_sub_ptr));
allgather_stage->AppendCommunicationGroup(new_comm_group);
allgather_stage->SetBytesOut(stage_comm_sizes.at(dim));
if (previous_instruction != nullptr) {
allgather_stage->AddOperand(previous_instruction);
}
RETURN_IF_ERROR(allgather_translator_->Translate(allgather_stage));
parallel_allgather.push_back(allgather_stage);
}
}
// If we don't have any communication in original comm_group within the
// current dimension, just leave it
if (comm_group_mesh.size() > 1) {
ASSIGN_OR_RETURN(auto allgather_mesh, Instruction::Create(
ASSIGN_OR_RETURN(auto allgather_root, Instruction::Create(
Opcode::kNull,
absl::StrCat(name_prefix, "_stage-", stage, "_root"),
allgather_sub_ptr));
previous_instruction = allgather_root;
for (auto& instr : parallel_allgather) {
allgather_root->AddOperand(instr);
}
}
// Check if we have non-trivial concentration and need to perform
// explicit local exchange step
if ((concentration_ > 1) && !integrated_local_exchange_) {
if (local_comm_group.size() > 1) {
ASSIGN_OR_RETURN(auto allgather_conc, Instruction::Create(
Opcode::kAllGather,
absl::StrCat(name_prefix, "_dim-", dim),
absl::StrCat(name_prefix,
"_conc"),
allgather_sub_ptr));
allgather_mesh->AppendCommunicationGroup(comm_group_mesh);
allgather_mesh->SetBytesOut(comm_size * dim_width /
comm_group.size());
allgather_conc->AppendCommunicationGroup(local_comm_group);
allgather_conc->SetBytesOut(comm_size);
if (previous_instruction != nullptr) {
allgather_mesh->AddOperand(previous_instruction);
allgather_conc->AddOperand(previous_instruction);
}
RETURN_IF_ERROR(allgather_translator_->Translate(allgather_mesh));
previous_instruction = allgather_mesh;
RETURN_IF_ERROR(allgather_translator_->Translate(allgather_conc));
previous_instruction = allgather_conc;
}
}
// Set root instruction for allgather subroutine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class Mesh2dAllGatherTranslator : public AllGatherTranslator {
std::vector<uint64_t> dimension_sizes_;
// Number of processors per mesh node
uint64_t concentration_;
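// Whether the local exchange among processors sharing a concentrator is
// integrated into the mesh dimension stages instead of being performed as a
// separate explicit step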
bool integrated_local_exchange_;
};

} // namespace paragraph
Expand Down
Loading