diff --git a/paragraph/translation/allgather/mesh_2d_allgather_translator.cc b/paragraph/translation/allgather/mesh_2d_allgather_translator.cc index d4f5739..b4361aa 100644 --- a/paragraph/translation/allgather/mesh_2d_allgather_translator.cc +++ b/paragraph/translation/allgather/mesh_2d_allgather_translator.cc @@ -44,6 +44,11 @@ Mesh2dAllGatherTranslator::Mesh2dAllGatherTranslator( if (config.find("concentration") != config.end()) { concentration_ = config["concentration"].get(); } + integrated_local_exchange_ = false; + if (config.find("integrated_local_exchange") != config.end()) { + integrated_local_exchange_ = + config["integrated_local_exchange"].get(); + } // Create json config for internal 1D Mesh all-gather nlohmann::json implicit_config = R"( @@ -76,63 +81,77 @@ shim::StatusOr> absl::InvalidArgumentError) << "Processor index points to the wrong Processor ID."; Instruction* previous_instruction = nullptr; - std::vector processor_coordinates; - std::unordered_set whole_world(comm_group.begin(), comm_group.end()); - // Check if we have non-trivial concentration first - if (concentration_ > 1) { - processor_coordinates = ConsecutiveProcessorIdToGridCoordinates( - processor_id, dimension_sizes_, concentration_); - CommunicationGroup comm_group_conc; - for (uint64_t i = 0; i < concentration_; i++) { - processor_coordinates.at(0) = i; - uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId( - processor_coordinates, dimension_sizes_, concentration_); - if (whole_world.find(new_processor_id) != whole_world.end()) { - comm_group_conc.push_back(new_processor_id); + CommunicationGroup local_comm_group = CommunicationGroupLocalProjection( + processor_id, comm_group, dimension_sizes_, concentration_); + std::vector stage_comm_sizes; + // We prepare communication sizes for each stage and each dimension as + // dimension and/or communication groups could be uneven + for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { + if ((comm_size == 0) || (comm_group.empty())) { + stage_comm_sizes.push_back(0); + } else { + stage_comm_sizes.push_back( + comm_size / comm_group.size() / dimension_sizes_.size()); + if (integrated_local_exchange_) { + // Additionally accomodating split of traffic from concentrator among + // dimensions + stage_comm_sizes.at(dim) /= dimension_sizes_.size(); } } - if (comm_group_conc.size() > 1) { - ASSIGN_OR_RETURN(auto allgather_conc, Instruction::Create( - Opcode::kAllGather, - absl::StrCat(name_prefix, - "_dim-conc"), - allgather_sub_ptr)); - allgather_conc->AppendCommunicationGroup(comm_group_conc); - allgather_conc->SetBytesOut(comm_size * concentration_ / - comm_group.size()); - RETURN_IF_ERROR(allgather_translator_->Translate(allgather_conc)); - previous_instruction = allgather_conc; - } } - // Now do the same for every dimension of the mesh - for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { - processor_coordinates = ConsecutiveProcessorIdToGridCoordinates( - processor_id, dimension_sizes_, concentration_); - CommunicationGroup comm_group_mesh; - uint64_t dim_width = dimension_sizes_.at(dim); - for (uint64_t i = 0; i < dim_width; i++) { - processor_coordinates.at(dim + 1) = i; - uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId( - processor_coordinates, dimension_sizes_, concentration_); - if (whole_world.find(new_processor_id) != whole_world.end()) { - comm_group_mesh.push_back(new_processor_id); + // We have as many stages as dimensions in the Mesh + for (size_t stage = 0; stage < dimension_sizes_.size(); stage++) { + // We run AllGather in parallel for every dimension of Mesh + std::vector parallel_allgather; + for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { + auto new_comm_group = CommunicationGroupProjectionOnGrid( + processor_id, comm_group, dim, integrated_local_exchange_, + dimension_sizes_, concentration_); + // Every new stage we should increase communication size + // On the first stage we only exchange data laying in the 1st dimension + // On the second stage we exchange data from both 1st and 2nd dimensions + stage_comm_sizes.at(dim) *= new_comm_group.size(); + // If we don't have any communication in original comm_group within the + // current dimension, just leave it + if (new_comm_group.size() > 1) { + ASSIGN_OR_RETURN(auto allgather_stage, Instruction::Create( + Opcode::kAllGather, + absl::StrCat(name_prefix, "_stage-", stage, "_dim-", dim), + allgather_sub_ptr)); + allgather_stage->AppendCommunicationGroup(new_comm_group); + allgather_stage->SetBytesOut(stage_comm_sizes.at(dim)); + if (previous_instruction != nullptr) { + allgather_stage->AddOperand(previous_instruction); + } + RETURN_IF_ERROR(allgather_translator_->Translate(allgather_stage)); + parallel_allgather.push_back(allgather_stage); } } - // If we don't have any communication in original comm_group within the - // current dimension, just leave it - if (comm_group_mesh.size() > 1) { - ASSIGN_OR_RETURN(auto allgather_mesh, Instruction::Create( + ASSIGN_OR_RETURN(auto allgather_root, Instruction::Create( + Opcode::kNull, + absl::StrCat(name_prefix, "_stage-", stage, "_root"), + allgather_sub_ptr)); + previous_instruction = allgather_root; + for (auto& instr : parallel_allgather) { + allgather_root->AddOperand(instr); + } + } + // Check if we have non-trivial concentration and need to perform + // explicit local exchange step + if ((concentration_ > 1) && !integrated_local_exchange_) { + if (local_comm_group.size() > 1) { + ASSIGN_OR_RETURN(auto allgather_conc, Instruction::Create( Opcode::kAllGather, - absl::StrCat(name_prefix, "_dim-", dim), + absl::StrCat(name_prefix, + "_conc"), allgather_sub_ptr)); - allgather_mesh->AppendCommunicationGroup(comm_group_mesh); - allgather_mesh->SetBytesOut(comm_size * dim_width / - comm_group.size()); + allgather_conc->AppendCommunicationGroup(local_comm_group); + allgather_conc->SetBytesOut(comm_size); if (previous_instruction != nullptr) { - allgather_mesh->AddOperand(previous_instruction); + allgather_conc->AddOperand(previous_instruction); } - RETURN_IF_ERROR(allgather_translator_->Translate(allgather_mesh)); - previous_instruction = allgather_mesh; + RETURN_IF_ERROR(allgather_translator_->Translate(allgather_conc)); + previous_instruction = allgather_conc; } } // Set root instruction for allgather subroutine diff --git a/paragraph/translation/allgather/mesh_2d_allgather_translator.h b/paragraph/translation/allgather/mesh_2d_allgather_translator.h index 0aeea47..9e84002 100644 --- a/paragraph/translation/allgather/mesh_2d_allgather_translator.h +++ b/paragraph/translation/allgather/mesh_2d_allgather_translator.h @@ -59,6 +59,8 @@ class Mesh2dAllGatherTranslator : public AllGatherTranslator { std::vector dimension_sizes_; // Number of processors per mesh node uint64_t concentration_; + // concentrators + bool integrated_local_exchange_; }; } // namespace paragraph diff --git a/paragraph/translation/allgather/mesh_2d_allgather_translator_test.cc b/paragraph/translation/allgather/mesh_2d_allgather_translator_test.cc index 60eca60..604f847 100644 --- a/paragraph/translation/allgather/mesh_2d_allgather_translator_test.cc +++ b/paragraph/translation/allgather/mesh_2d_allgather_translator_test.cc @@ -24,45 +24,9 @@ #include "paragraph/shim/test_macros.h" #include "paragraph/translation/translation_map.h" -// Tests expanding 2D-Mesh all-gather -TEST(Mesh2dAllGather, NoBarrier) { - auto graph = absl::make_unique("test_graph", 1); - auto sub = absl::make_unique( - "test_subroutine", graph.get()); - auto sub_ptr = sub.get(); - graph->SetEntrySubroutine(std::move(sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); - instr_1->SetOps(4); - - ASSERT_OK_AND_ASSIGN(auto allgather, - paragraph::Instruction::Create( - paragraph::Opcode::kAllGather, "all-gather", sub_ptr)); - allgather->SetBytesOut(80); - paragraph::CommunicationGroup allgather_group = {0, 1, 2, 3, 4, 5, 6, 7}; - allgather->AppendCommunicationGroup(allgather_group); - - ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); - instr_3->SetOps(4); - - nlohmann::json config = R"( - { - "all-gather": { - "algorithm": "mesh-2d", - "concentration": 2, - "dimension_widths": [2, 2] - } - } - )"_json; - - ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( - paragraph::TranslatorType::kCollective, config)); - EXPECT_OK(translators["all-gather"]->Translate(allgather)); - - paragraph::InstructionProto allgather_proto; - std::string allgather_str = +paragraph::InstructionProto Mesh2dAllGather_no_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "all-gather" opcode: "all-gather" @@ -80,25 +44,27 @@ communication_groups { } inner_subroutines { name: "all-gather_mesh-2d" - subroutine_root_id: 8 + subroutine_root_id: 37 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc" + name: "all-gather_stage-0_dim-0" opcode: "all-gather" instruction_id: 4 - bytes_out: 20 + bytes_out: 40 communication_groups { group_ids: 0 group_ids: 1 + group_ids: 2 + group_ids: 3 } inner_subroutines { - name: "all-gather_dim-conc_mesh-1d" - subroutine_root_id: 5 + name: "all-gather_stage-0_dim-0_mesh-1d" + subroutine_root_id: 11 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc_mesh-1d_ccw_sendrecv_0" + name: "all-gather_stage-0_dim-0_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" instruction_id: 5 bytes_in: 10 @@ -108,79 +74,352 @@ inner_subroutines { group_ids: 0 } } + instructions { + name: "all-gather_stage-0_dim-0_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 6 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 2 + group_ids: 2 + } + } + instructions { + name: "all-gather_stage-0_dim-0_mesh-1d_root_0" + opcode: "null" + instruction_id: 7 + operand_ids: 6 + operand_ids: 5 + } + instructions { + name: "all-gather_stage-0_dim-0_mesh-1d_ccw_send_1" + opcode: "send" + instruction_id: 8 + bytes_out: 10 + communication_groups { + group_ids: 0 + } + operand_ids: 7 + } + instructions { + name: "all-gather_stage-0_dim-0_mesh-1d_cw_sendrecv_1" + opcode: "sendrecv" + instruction_id: 9 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 2 + group_ids: 2 + } + operand_ids: 7 + } + instructions { + name: "all-gather_stage-0_dim-0_mesh-1d_root_1" + opcode: "null" + instruction_id: 10 + operand_ids: 9 + operand_ids: 8 + } + instructions { + name: "all-gather_stage-0_dim-0_mesh-1d_ccw_send_2" + opcode: "send" + instruction_id: 11 + bytes_out: 10 + communication_groups { + group_ids: 0 + } + operand_ids: 10 + } } } instructions { - name: "all-gather_dim-0" + name: "all-gather_stage-0_dim-1" opcode: "all-gather" - instruction_id: 6 - bytes_out: 20 + instruction_id: 12 + bytes_out: 40 communication_groups { + group_ids: 0 group_ids: 1 - group_ids: 3 + group_ids: 4 + group_ids: 5 } - operand_ids: 4 inner_subroutines { - name: "all-gather_dim-0_mesh-1d" - subroutine_root_id: 7 + name: "all-gather_stage-0_dim-1_mesh-1d" + subroutine_root_id: 19 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-0_mesh-1d_cw_sendrecv_0" + name: "all-gather_stage-0_dim-1_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" - instruction_id: 7 + instruction_id: 13 bytes_in: 10 bytes_out: 10 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "all-gather_stage-0_dim-1_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 14 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 4 + group_ids: 4 + } + } + instructions { + name: "all-gather_stage-0_dim-1_mesh-1d_root_0" + opcode: "null" + instruction_id: 15 + operand_ids: 14 + operand_ids: 13 + } + instructions { + name: "all-gather_stage-0_dim-1_mesh-1d_ccw_send_1" + opcode: "send" + instruction_id: 16 + bytes_out: 10 + communication_groups { + group_ids: 0 + } + operand_ids: 15 + } + instructions { + name: "all-gather_stage-0_dim-1_mesh-1d_cw_sendrecv_1" + opcode: "sendrecv" + instruction_id: 17 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 4 + group_ids: 4 } + operand_ids: 15 + } + instructions { + name: "all-gather_stage-0_dim-1_mesh-1d_root_1" + opcode: "null" + instruction_id: 18 + operand_ids: 17 + operand_ids: 16 + } + instructions { + name: "all-gather_stage-0_dim-1_mesh-1d_ccw_send_2" + opcode: "send" + instruction_id: 19 + bytes_out: 10 + communication_groups { + group_ids: 0 + } + operand_ids: 18 } } } instructions { - name: "all-gather_dim-1" + name: "all-gather_stage-0_root" + opcode: "null" + instruction_id: 20 + operand_ids: 4 + operand_ids: 12 + } + instructions { + name: "all-gather_stage-1_dim-0" opcode: "all-gather" - instruction_id: 8 - bytes_out: 20 + instruction_id: 21 + bytes_out: 80 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 2 + group_ids: 3 + } + operand_ids: 20 + inner_subroutines { + name: "all-gather_stage-1_dim-0_mesh-1d" + subroutine_root_id: 28 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-1_dim-0_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 22 + bytes_in: 20 + bytes_out: 20 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "all-gather_stage-1_dim-0_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 23 + bytes_in: 20 + bytes_out: 20 + communication_groups { + group_ids: 2 + group_ids: 2 + } + } + instructions { + name: "all-gather_stage-1_dim-0_mesh-1d_root_0" + opcode: "null" + instruction_id: 24 + operand_ids: 23 + operand_ids: 22 + } + instructions { + name: "all-gather_stage-1_dim-0_mesh-1d_ccw_send_1" + opcode: "send" + instruction_id: 25 + bytes_out: 20 + communication_groups { + group_ids: 0 + } + operand_ids: 24 + } + instructions { + name: "all-gather_stage-1_dim-0_mesh-1d_cw_sendrecv_1" + opcode: "sendrecv" + instruction_id: 26 + bytes_in: 20 + bytes_out: 20 + communication_groups { + group_ids: 2 + group_ids: 2 + } + operand_ids: 24 + } + instructions { + name: "all-gather_stage-1_dim-0_mesh-1d_root_1" + opcode: "null" + instruction_id: 27 + operand_ids: 26 + operand_ids: 25 + } + instructions { + name: "all-gather_stage-1_dim-0_mesh-1d_ccw_send_2" + opcode: "send" + instruction_id: 28 + bytes_out: 20 + communication_groups { + group_ids: 0 + } + operand_ids: 27 + } + } + } + instructions { + name: "all-gather_stage-1_dim-1" + opcode: "all-gather" + instruction_id: 29 + bytes_out: 80 communication_groups { + group_ids: 0 group_ids: 1 + group_ids: 4 group_ids: 5 } - operand_ids: 6 + operand_ids: 20 inner_subroutines { - name: "all-gather_dim-1_mesh-1d" - subroutine_root_id: 9 + name: "all-gather_stage-1_dim-1_mesh-1d" + subroutine_root_id: 36 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_mesh-1d_cw_sendrecv_0" + name: "all-gather_stage-1_dim-1_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" - instruction_id: 9 - bytes_in: 10 - bytes_out: 10 + instruction_id: 30 + bytes_in: 20 + bytes_out: 20 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 31 + bytes_in: 20 + bytes_out: 20 communication_groups { - group_ids: 5 - group_ids: 5 + group_ids: 4 + group_ids: 4 + } + } + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_root_0" + opcode: "null" + instruction_id: 32 + operand_ids: 31 + operand_ids: 30 + } + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_ccw_send_1" + opcode: "send" + instruction_id: 33 + bytes_out: 20 + communication_groups { + group_ids: 0 + } + operand_ids: 32 + } + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_cw_sendrecv_1" + opcode: "sendrecv" + instruction_id: 34 + bytes_in: 20 + bytes_out: 20 + communication_groups { + group_ids: 4 + group_ids: 4 + } + operand_ids: 32 + } + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_root_1" + opcode: "null" + instruction_id: 35 + operand_ids: 34 + operand_ids: 33 + } + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_ccw_send_2" + opcode: "send" + instruction_id: 36 + bytes_out: 20 + communication_groups { + group_ids: 0 } + operand_ids: 35 } } } + instructions { + name: "all-gather_stage-1_root" + opcode: "null" + instruction_id: 37 + operand_ids: 21 + operand_ids: 29 + } } - )proto"; - google::protobuf::TextFormat::ParseFromString(allgather_str, - &allgather_proto); - EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( - allgather->ToProto().value(), allgather_proto)); -} + )proto"; + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT -// Tests expanding 1D-Mesh all-gather with barrier -TEST(Mesh2dAllGather, WithBarrier) { - auto graph = absl::make_unique("test_graph", 2); +// Tests expanding 2D-Mesh all-gather +TEST(Mesh2dAllGather, NoBarrier) { + auto graph = absl::make_unique("test_graph", 1); auto sub = absl::make_unique( "test_subroutine", graph.get()); auto sub_ptr = sub.get(); - sub_ptr->SetId(3); graph->SetEntrySubroutine(std::move(sub)); ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( @@ -204,247 +443,393 @@ TEST(Mesh2dAllGather, WithBarrier) { "algorithm": "mesh-2d", "concentration": 2, "dimension_widths": [2, 2], - "barrier": { - "algorithm": "centralized" + "integrated_local_exchange": true + } + } + )"_json; + + ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( + paragraph::TranslatorType::kCollective, config)); + EXPECT_OK(translators["all-gather"]->Translate(allgather)); + + paragraph::InstructionProto allgather_proto = + Mesh2dAllGather_no_barrier_test_proto(); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + allgather->ToProto().value(), allgather_proto)); +} + +paragraph::InstructionProto Mesh2dAllGather_with_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = + R"proto( +name: "all-gather" +opcode: "all-gather" +instruction_id: 2 +bytes_out: 80 +communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 2 + group_ids: 3 + group_ids: 4 + group_ids: 5 + group_ids: 6 + group_ids: 7 +} +inner_subroutines { + name: "all-gather_mesh-2d" + subroutine_root_id: 28 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-0_dim-0" + opcode: "all-gather" + instruction_id: 4 + bytes_out: 20 + communication_groups { + group_ids: 0 + group_ids: 2 + } + inner_subroutines { + name: "all-gather_stage-0_dim-0_mesh-1d" + subroutine_root_id: 8 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-0_dim-0_mesh-1d_barrier" + opcode: "barrier" + instruction_id: 5 + communication_groups { + group_ids: 0 + group_ids: 2 + } + inner_subroutines { + name: "all-gather_stage-0_dim-0_mesh-1d_barrier_centralized" + subroutine_root_id: 7 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-0_dim-0_mesh-1d_barrier_centralized_send_to_0" + opcode: "send" + instruction_id: 6 + communication_groups { + group_ids: 0 + } + } + instructions { + name: "all-gather_stage-0_dim-0_mesh-1d_barrier_centralized_recv_from_0" + opcode: "recv" + instruction_id: 7 + communication_groups { + group_ids: 0 + } + operand_ids: 6 + } + } + } + instructions { + name: "all-gather_stage-0_dim-0_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 8 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 0 } + operand_ids: 5 } } - )"_json; - - ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( - paragraph::TranslatorType::kCollective, config)); - EXPECT_OK(translators["all-gather"]->Translate(allgather)); - - paragraph::InstructionProto allgather_proto; - std::string allgather_str = - R"proto( -name: "all-gather" -opcode: "all-gather" -instruction_id: 2 -bytes_out: 80 -communication_groups { - group_ids: 0 - group_ids: 1 - group_ids: 2 - group_ids: 3 - group_ids: 4 - group_ids: 5 - group_ids: 6 - group_ids: 7 -} -inner_subroutines { - name: "all-gather_mesh-2d" - subroutine_root_id: 15 - execution_probability: 1 - execution_count: 1 + } instructions { - name: "all-gather_dim-conc" + name: "all-gather_stage-0_dim-1" opcode: "all-gather" - instruction_id: 4 + instruction_id: 9 bytes_out: 20 communication_groups { group_ids: 2 - group_ids: 3 + group_ids: 6 } inner_subroutines { - name: "all-gather_dim-conc_mesh-1d" - subroutine_root_id: 9 + name: "all-gather_stage-0_dim-1_mesh-1d" + subroutine_root_id: 14 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc_mesh-1d_barrier" + name: "all-gather_stage-0_dim-1_mesh-1d_barrier" opcode: "barrier" - instruction_id: 5 + instruction_id: 10 communication_groups { group_ids: 2 - group_ids: 3 + group_ids: 6 } inner_subroutines { - name: "all-gather_dim-conc_mesh-1d_barrier_centralized" - subroutine_root_id: 8 + name: "all-gather_stage-0_dim-1_mesh-1d_barrier_centralized" + subroutine_root_id: 13 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc_mesh-1d_barrier_centralized_coordinator_recv_from_3" + name: "all-gather_stage-0_dim-1_mesh-1d_barrier_centralized_coordinator_recv_from_6" opcode: "recv" - instruction_id: 6 + instruction_id: 11 communication_groups { - group_ids: 3 + group_ids: 6 } } instructions { - name: "all-gather_dim-conc_mesh-1d_barrier_centralized_coordinator_send_to_3" + name: "all-gather_stage-0_dim-1_mesh-1d_barrier_centralized_coordinator_send_to_6" opcode: "send" - instruction_id: 7 + instruction_id: 12 communication_groups { - group_ids: 3 + group_ids: 6 } - operand_ids: 6 + operand_ids: 11 } instructions { - name: "all-gather_dim-conc_mesh-1d_barrier_centralized_root_2" + name: "all-gather_stage-0_dim-1_mesh-1d_barrier_centralized_root_2" opcode: "null" - instruction_id: 8 - operand_ids: 7 + instruction_id: 13 + operand_ids: 12 } } } instructions { - name: "all-gather_dim-conc_mesh-1d_cw_sendrecv_0" + name: "all-gather_stage-0_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" - instruction_id: 9 + instruction_id: 14 bytes_in: 10 bytes_out: 10 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 6 + group_ids: 6 } - operand_ids: 5 + operand_ids: 10 } } } instructions { - name: "all-gather_dim-0" + name: "all-gather_stage-0_root" + opcode: "null" + instruction_id: 15 + operand_ids: 4 + operand_ids: 9 + } + instructions { + name: "all-gather_stage-1_dim-0" opcode: "all-gather" - instruction_id: 10 - bytes_out: 20 + instruction_id: 16 + bytes_out: 40 communication_groups { group_ids: 0 group_ids: 2 } - operand_ids: 4 + operand_ids: 15 inner_subroutines { - name: "all-gather_dim-0_mesh-1d" - subroutine_root_id: 14 + name: "all-gather_stage-1_dim-0_mesh-1d" + subroutine_root_id: 20 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-0_mesh-1d_barrier" + name: "all-gather_stage-1_dim-0_mesh-1d_barrier" opcode: "barrier" - instruction_id: 11 + instruction_id: 17 communication_groups { group_ids: 0 group_ids: 2 } inner_subroutines { - name: "all-gather_dim-0_mesh-1d_barrier_centralized" - subroutine_root_id: 13 + name: "all-gather_stage-1_dim-0_mesh-1d_barrier_centralized" + subroutine_root_id: 19 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-0_mesh-1d_barrier_centralized_send_to_0" + name: "all-gather_stage-1_dim-0_mesh-1d_barrier_centralized_send_to_0" opcode: "send" - instruction_id: 12 + instruction_id: 18 communication_groups { group_ids: 0 } } instructions { - name: "all-gather_dim-0_mesh-1d_barrier_centralized_recv_from_0" + name: "all-gather_stage-1_dim-0_mesh-1d_barrier_centralized_recv_from_0" opcode: "recv" - instruction_id: 13 + instruction_id: 19 communication_groups { group_ids: 0 } - operand_ids: 12 + operand_ids: 18 } } } instructions { - name: "all-gather_dim-0_mesh-1d_ccw_sendrecv_0" + name: "all-gather_stage-1_dim-0_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" - instruction_id: 14 - bytes_in: 10 - bytes_out: 10 + instruction_id: 20 + bytes_in: 20 + bytes_out: 20 communication_groups { group_ids: 0 group_ids: 0 } - operand_ids: 11 + operand_ids: 17 } } } instructions { - name: "all-gather_dim-1" + name: "all-gather_stage-1_dim-1" opcode: "all-gather" - instruction_id: 15 - bytes_out: 20 + instruction_id: 21 + bytes_out: 40 communication_groups { group_ids: 2 group_ids: 6 } - operand_ids: 10 + operand_ids: 15 inner_subroutines { - name: "all-gather_dim-1_mesh-1d" - subroutine_root_id: 20 + name: "all-gather_stage-1_dim-1_mesh-1d" + subroutine_root_id: 26 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_mesh-1d_barrier" + name: "all-gather_stage-1_dim-1_mesh-1d_barrier" opcode: "barrier" - instruction_id: 16 + instruction_id: 22 communication_groups { group_ids: 2 group_ids: 6 } inner_subroutines { - name: "all-gather_dim-1_mesh-1d_barrier_centralized" - subroutine_root_id: 19 + name: "all-gather_stage-1_dim-1_mesh-1d_barrier_centralized" + subroutine_root_id: 25 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_mesh-1d_barrier_centralized_coordinator_recv_from_6" + name: "all-gather_stage-1_dim-1_mesh-1d_barrier_centralized_coordinator_recv_from_6" opcode: "recv" - instruction_id: 17 + instruction_id: 23 communication_groups { group_ids: 6 } } instructions { - name: "all-gather_dim-1_mesh-1d_barrier_centralized_coordinator_send_to_6" + name: "all-gather_stage-1_dim-1_mesh-1d_barrier_centralized_coordinator_send_to_6" opcode: "send" - instruction_id: 18 + instruction_id: 24 communication_groups { group_ids: 6 } - operand_ids: 17 + operand_ids: 23 } instructions { - name: "all-gather_dim-1_mesh-1d_barrier_centralized_root_2" + name: "all-gather_stage-1_dim-1_mesh-1d_barrier_centralized_root_2" opcode: "null" - instruction_id: 19 - operand_ids: 18 + instruction_id: 25 + operand_ids: 24 } } } instructions { - name: "all-gather_dim-1_mesh-1d_cw_sendrecv_0" + name: "all-gather_stage-1_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" - instruction_id: 20 - bytes_in: 10 - bytes_out: 10 + instruction_id: 26 + bytes_in: 20 + bytes_out: 20 communication_groups { group_ids: 6 group_ids: 6 } - operand_ids: 16 + operand_ids: 22 + } + } + } + instructions { + name: "all-gather_stage-1_root" + opcode: "null" + instruction_id: 27 + operand_ids: 16 + operand_ids: 21 + } + instructions { + name: "all-gather_conc" + opcode: "all-gather" + instruction_id: 28 + bytes_out: 80 + communication_groups { + group_ids: 2 + group_ids: 3 + } + operand_ids: 27 + inner_subroutines { + name: "all-gather_conc_mesh-1d" + subroutine_root_id: 33 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_conc_mesh-1d_barrier" + opcode: "barrier" + instruction_id: 29 + communication_groups { + group_ids: 2 + group_ids: 3 + } + inner_subroutines { + name: "all-gather_conc_mesh-1d_barrier_centralized" + subroutine_root_id: 32 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_conc_mesh-1d_barrier_centralized_coordinator_recv_from_3" + opcode: "recv" + instruction_id: 30 + communication_groups { + group_ids: 3 + } + } + instructions { + name: "all-gather_conc_mesh-1d_barrier_centralized_coordinator_send_to_3" + opcode: "send" + instruction_id: 31 + communication_groups { + group_ids: 3 + } + operand_ids: 30 + } + instructions { + name: "all-gather_conc_mesh-1d_barrier_centralized_root_2" + opcode: "null" + instruction_id: 32 + operand_ids: 31 + } + } + } + instructions { + name: "all-gather_conc_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 33 + bytes_in: 40 + bytes_out: 40 + communication_groups { + group_ids: 3 + group_ids: 3 + } + operand_ids: 29 } } } } - )proto"; - google::protobuf::TextFormat::ParseFromString(allgather_str, - &allgather_proto); - EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( - allgather->ToProto().value(), allgather_proto)); -} + )proto"; + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT -// Tests expanding 1D-Mesh all-gather -TEST(Mesh2dAllGather, InconsecutiveProcessors) { +// Tests expanding 1D-Mesh all-gather with barrier +TEST(Mesh2dAllGather, WithBarrier) { auto graph = absl::make_unique("test_graph", 2); auto sub = absl::make_unique( "test_subroutine", graph.get()); auto sub_ptr = sub.get(); + sub_ptr->SetId(3); graph->SetEntrySubroutine(std::move(sub)); ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( @@ -454,8 +839,8 @@ TEST(Mesh2dAllGather, InconsecutiveProcessors) { ASSERT_OK_AND_ASSIGN(auto allgather, paragraph::Instruction::Create( paragraph::Opcode::kAllGather, "all-gather", sub_ptr)); - allgather->SetBytesOut(48); - paragraph::CommunicationGroup allgather_group = {0, 2, 4}; + allgather->SetBytesOut(80); + paragraph::CommunicationGroup allgather_group = {0, 1, 2, 3, 4, 5, 6, 7}; allgather->AppendCommunicationGroup(allgather_group); ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( @@ -466,7 +851,11 @@ TEST(Mesh2dAllGather, InconsecutiveProcessors) { { "all-gather": { "algorithm": "mesh-2d", - "dimension_widths": [2, 3] + "concentration": 2, + "dimension_widths": [2, 2], + "barrier": { + "algorithm": "centralized" + } } } )"_json; @@ -475,8 +864,16 @@ TEST(Mesh2dAllGather, InconsecutiveProcessors) { paragraph::TranslatorType::kCollective, config)); EXPECT_OK(translators["all-gather"]->Translate(allgather)); - paragraph::InstructionProto allgather_proto; - std::string allgather_str = + paragraph::InstructionProto allgather_proto = + Mesh2dAllGather_with_barrier_test_proto(); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + allgather->ToProto().value(), allgather_proto)); +} + +paragraph::InstructionProto +Mesh2dAllGather_inconsecutive_proc_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "all-gather" opcode: "all-gather" @@ -489,11 +886,11 @@ communication_groups { } inner_subroutines { name: "all-gather_mesh-2d" - subroutine_root_id: 4 + subroutine_root_id: 19 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1" + name: "all-gather_stage-0_dim-1" opcode: "all-gather" instruction_id: 4 bytes_out: 48 @@ -503,12 +900,12 @@ inner_subroutines { group_ids: 4 } inner_subroutines { - name: "all-gather_dim-1_mesh-1d" + name: "all-gather_stage-0_dim-1_mesh-1d" subroutine_root_id: 10 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_mesh-1d_ccw_sendrecv_0" + name: "all-gather_stage-0_dim-1_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" instruction_id: 5 bytes_in: 16 @@ -519,7 +916,7 @@ inner_subroutines { } } instructions { - name: "all-gather_dim-1_mesh-1d_cw_sendrecv_0" + name: "all-gather_stage-0_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" instruction_id: 6 bytes_in: 16 @@ -530,14 +927,14 @@ inner_subroutines { } } instructions { - name: "all-gather_dim-1_mesh-1d_root_0" + name: "all-gather_stage-0_dim-1_mesh-1d_root_0" opcode: "null" instruction_id: 7 operand_ids: 6 operand_ids: 5 } instructions { - name: "all-gather_dim-1_mesh-1d_ccw_send_1" + name: "all-gather_stage-0_dim-1_mesh-1d_ccw_send_1" opcode: "send" instruction_id: 8 bytes_out: 16 @@ -547,7 +944,7 @@ inner_subroutines { operand_ids: 7 } instructions { - name: "all-gather_dim-1_mesh-1d_cw_send_1" + name: "all-gather_stage-0_dim-1_mesh-1d_cw_send_1" opcode: "send" instruction_id: 9 bytes_out: 16 @@ -557,7 +954,7 @@ inner_subroutines { operand_ids: 7 } instructions { - name: "all-gather_dim-1_mesh-1d_root_1" + name: "all-gather_stage-0_dim-1_mesh-1d_root_1" opcode: "null" instruction_id: 10 operand_ids: 9 @@ -565,10 +962,137 @@ inner_subroutines { } } } + instructions { + name: "all-gather_stage-0_root" + opcode: "null" + instruction_id: 11 + operand_ids: 4 + } + instructions { + name: "all-gather_stage-1_dim-1" + opcode: "all-gather" + instruction_id: 12 + bytes_out: 144 + communication_groups { + group_ids: 0 + group_ids: 2 + group_ids: 4 + } + operand_ids: 11 + inner_subroutines { + name: "all-gather_stage-1_dim-1_mesh-1d" + subroutine_root_id: 18 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 13 + bytes_in: 48 + bytes_out: 48 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 14 + bytes_in: 48 + bytes_out: 48 + communication_groups { + group_ids: 4 + group_ids: 4 + } + } + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_root_0" + opcode: "null" + instruction_id: 15 + operand_ids: 14 + operand_ids: 13 + } + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_ccw_send_1" + opcode: "send" + instruction_id: 16 + bytes_out: 48 + communication_groups { + group_ids: 0 + } + operand_ids: 15 + } + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_cw_send_1" + opcode: "send" + instruction_id: 17 + bytes_out: 48 + communication_groups { + group_ids: 4 + } + operand_ids: 15 + } + instructions { + name: "all-gather_stage-1_dim-1_mesh-1d_root_1" + opcode: "null" + instruction_id: 18 + operand_ids: 17 + operand_ids: 16 + } + } + } + instructions { + name: "all-gather_stage-1_root" + opcode: "null" + instruction_id: 19 + operand_ids: 12 + } } - )proto"; - google::protobuf::TextFormat::ParseFromString(allgather_str, - &allgather_proto); + )proto"; + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT + +// Tests expanding 1D-Mesh all-gather +TEST(Mesh2dAllGather, InconsecutiveProcessors) { + auto graph = absl::make_unique("test_graph", 2); + auto sub = absl::make_unique( + "test_subroutine", graph.get()); + auto sub_ptr = sub.get(); + graph->SetEntrySubroutine(std::move(sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); + instr_1->SetOps(4); + + ASSERT_OK_AND_ASSIGN(auto allgather, + paragraph::Instruction::Create( + paragraph::Opcode::kAllGather, "all-gather", sub_ptr)); + allgather->SetBytesOut(48); + paragraph::CommunicationGroup allgather_group = {0, 2, 4}; + allgather->AppendCommunicationGroup(allgather_group); + + ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); + instr_3->SetOps(4); + + nlohmann::json config = R"( + { + "all-gather": { + "algorithm": "mesh-2d", + "dimension_widths": [2, 3] + } + } + )"_json; + + ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( + paragraph::TranslatorType::kCollective, config)); + EXPECT_OK(translators["all-gather"]->Translate(allgather)); + + paragraph::InstructionProto allgather_proto = + Mesh2dAllGather_inconsecutive_proc_test_proto(); EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( allgather->ToProto().value(), allgather_proto)); } diff --git a/paragraph/translation/allgather/torus_2d_allgather_translator.cc b/paragraph/translation/allgather/torus_2d_allgather_translator.cc index ee7f214..0fd273e 100644 --- a/paragraph/translation/allgather/torus_2d_allgather_translator.cc +++ b/paragraph/translation/allgather/torus_2d_allgather_translator.cc @@ -1,4 +1,7 @@ /* Copyright 2021 Google LLC + if (integrated_local_exchange_) { + stage_comm_sizes.at(dim) /= local_comm_group.size(); + } * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,6 +47,12 @@ Torus2dAllGatherTranslator::Torus2dAllGatherTranslator( if (config.find("concentration") != config.end()) { concentration_ = config["concentration"].get(); } + // conentrated ports + integrated_local_exchange_ = false; + if (config.find("integrated_local_exchange") != config.end()) { + integrated_local_exchange_ = + config["integrated_local_exchange"].get(); + } // Create json config for internal 1D Torus all-gather nlohmann::json implicit_config = R"( @@ -76,63 +85,77 @@ shim::StatusOr> absl::InvalidArgumentError) << "Processor index points to the wrong Processor ID."; Instruction* previous_instruction = nullptr; - std::vector processor_coordinates; - std::unordered_set whole_world(comm_group.begin(), comm_group.end()); - // Check if we have non-trivial concentration first - if (concentration_ > 1) { - processor_coordinates = ConsecutiveProcessorIdToGridCoordinates( - processor_id, dimension_sizes_, concentration_); - CommunicationGroup comm_group_conc; - for (uint64_t i = 0; i < concentration_; i++) { - processor_coordinates.at(0) = i; - uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId( - processor_coordinates, dimension_sizes_, concentration_); - if (whole_world.find(new_processor_id) != whole_world.end()) { - comm_group_conc.push_back(new_processor_id); + CommunicationGroup local_comm_group = CommunicationGroupLocalProjection( + processor_id, comm_group, dimension_sizes_, concentration_); + std::vector stage_comm_sizes; + // We prepare communication sizes for each stage and each dimension as + // dimension and/or communication groups could be uneven + for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { + if ((comm_size == 0) || (comm_group.empty())) { + stage_comm_sizes.push_back(0); + } else { + stage_comm_sizes.push_back( + comm_size / comm_group.size() / dimension_sizes_.size()); + if (integrated_local_exchange_) { + // Additionally accomodating split of traffic from concentrator among + // dimensions + stage_comm_sizes.at(dim) /= dimension_sizes_.size(); } } - if (comm_group_conc.size() > 1) { - ASSIGN_OR_RETURN(auto allgather_conc, Instruction::Create( - Opcode::kAllGather, - absl::StrCat(name_prefix, - "_dim-conc"), - allgather_sub_ptr)); - allgather_conc->AppendCommunicationGroup(comm_group_conc); - allgather_conc->SetBytesOut(comm_size * concentration_ / - comm_group.size()); - RETURN_IF_ERROR(allgather_translator_->Translate(allgather_conc)); - previous_instruction = allgather_conc; - } } - // Now do the same for every dimension of the torus - for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { - processor_coordinates = ConsecutiveProcessorIdToGridCoordinates( - processor_id, dimension_sizes_, concentration_); - CommunicationGroup comm_group_torus; - uint64_t dim_width = dimension_sizes_.at(dim); - for (uint64_t i = 0; i < dim_width; i++) { - processor_coordinates.at(dim + 1) = i; - uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId( - processor_coordinates, dimension_sizes_, concentration_); - if (whole_world.find(new_processor_id) != whole_world.end()) { - comm_group_torus.push_back(new_processor_id); + // We have as many stages as dimensions in the Torus + for (size_t stage = 0; stage < dimension_sizes_.size(); stage++) { + // We run AllGather in parallel for every dimension of Torus + std::vector parallel_allgather; + for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { + auto new_comm_group = CommunicationGroupProjectionOnGrid( + processor_id, comm_group, dim, integrated_local_exchange_, + dimension_sizes_, concentration_); + // Every new stage we should increase communication size + // On the first stage we only exchange data laying in the 1st dimension + // On the second stage we exchange data from both 1st and 2nd dimensions + stage_comm_sizes.at(dim) *= new_comm_group.size(); + // If we don't have any communication in original comm_group within the + // current dimension, just leave it + if (new_comm_group.size() > 1) { + ASSIGN_OR_RETURN(auto allgather_stage, Instruction::Create( + Opcode::kAllGather, + absl::StrCat(name_prefix, "_stage-", stage, "_dim-", dim), + allgather_sub_ptr)); + allgather_stage->AppendCommunicationGroup(new_comm_group); + allgather_stage->SetBytesOut(stage_comm_sizes.at(dim)); + if (previous_instruction != nullptr) { + allgather_stage->AddOperand(previous_instruction); + } + RETURN_IF_ERROR(allgather_translator_->Translate(allgather_stage)); + parallel_allgather.push_back(allgather_stage); } } - // If we don't have any communication in original comm_group within the - // current dimension, just leave it - if (comm_group_torus.size() > 1) { - ASSIGN_OR_RETURN(auto allgather_torus, Instruction::Create( + ASSIGN_OR_RETURN(auto allgather_root, Instruction::Create( + Opcode::kNull, + absl::StrCat(name_prefix, "_stage-", stage, "_root"), + allgather_sub_ptr)); + previous_instruction = allgather_root; + for (auto& instr : parallel_allgather) { + allgather_root->AddOperand(instr); + } + } + // Check if we have non-trivial concentration and need to perform + // explicit local exchange step + if ((concentration_ > 1) && !integrated_local_exchange_) { + if (local_comm_group.size() > 1) { + ASSIGN_OR_RETURN(auto allgather_conc, Instruction::Create( Opcode::kAllGather, - absl::StrCat(name_prefix, "_dim-", dim), + absl::StrCat(name_prefix, + "_conc"), allgather_sub_ptr)); - allgather_torus->AppendCommunicationGroup(comm_group_torus); - allgather_torus->SetBytesOut(comm_size * dim_width / - comm_group.size()); + allgather_conc->AppendCommunicationGroup(local_comm_group); + allgather_conc->SetBytesOut(comm_size); if (previous_instruction != nullptr) { - allgather_torus->AddOperand(previous_instruction); + allgather_conc->AddOperand(previous_instruction); } - RETURN_IF_ERROR(allgather_translator_->Translate(allgather_torus)); - previous_instruction = allgather_torus; + RETURN_IF_ERROR(allgather_translator_->Translate(allgather_conc)); + previous_instruction = allgather_conc; } } // Set root instruction for allgather subroutine diff --git a/paragraph/translation/allgather/torus_2d_allgather_translator.h b/paragraph/translation/allgather/torus_2d_allgather_translator.h index 1c83b7f..e38b764 100644 --- a/paragraph/translation/allgather/torus_2d_allgather_translator.h +++ b/paragraph/translation/allgather/torus_2d_allgather_translator.h @@ -55,6 +55,8 @@ class Torus2dAllGatherTranslator : public AllGatherTranslator { std::vector dimension_sizes_; // Number of processors per torus node uint64_t concentration_; + // concentrators + bool integrated_local_exchange_; }; } // namespace paragraph diff --git a/paragraph/translation/allgather/torus_2d_allgather_translator_test.cc b/paragraph/translation/allgather/torus_2d_allgather_translator_test.cc index 3454fa7..804124a 100644 --- a/paragraph/translation/allgather/torus_2d_allgather_translator_test.cc +++ b/paragraph/translation/allgather/torus_2d_allgather_translator_test.cc @@ -24,45 +24,9 @@ #include "paragraph/shim/test_macros.h" #include "paragraph/translation/translation_map.h" -// Tests expanding 2D-Torus all-gather -TEST(Torus2dAllGather, NoBarrier) { - auto graph = absl::make_unique("test_graph", 1); - auto sub = absl::make_unique( - "test_subroutine", graph.get()); - auto sub_ptr = sub.get(); - graph->SetEntrySubroutine(std::move(sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); - instr_1->SetOps(4); - - ASSERT_OK_AND_ASSIGN(auto allgather, - paragraph::Instruction::Create( - paragraph::Opcode::kAllGather, "all-gather", sub_ptr)); - allgather->SetBytesOut(80); - paragraph::CommunicationGroup allgather_group = {0, 1, 2, 3, 4, 5, 6, 7}; - allgather->AppendCommunicationGroup(allgather_group); - - ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); - instr_3->SetOps(4); - - nlohmann::json config = R"( - { - "all-gather": { - "algorithm": "torus-2d", - "concentration": 2, - "dimension_widths": [2, 2] - } - } - )"_json; - - ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( - paragraph::TranslatorType::kCollective, config)); - EXPECT_OK(translators["all-gather"]->Translate(allgather)); - - paragraph::InstructionProto allgather_proto; - std::string allgather_str = +paragraph::InstructionProto no_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "all-gather" opcode: "all-gather" @@ -80,272 +44,577 @@ communication_groups { } inner_subroutines { name: "all-gather_torus-2d" - subroutine_root_id: 16 + subroutine_root_id: 45 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc" + name: "all-gather_stage-0_dim-0" opcode: "all-gather" instruction_id: 4 - bytes_out: 20 + bytes_out: 40 communication_groups { group_ids: 0 group_ids: 1 + group_ids: 2 + group_ids: 3 } inner_subroutines { - name: "all-gather_dim-conc_bidir-ring" - subroutine_root_id: 9 + name: "all-gather_stage-0_dim-0_bidir-ring" + subroutine_root_id: 13 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc_bidir-ring_cw" + name: "all-gather_stage-0_dim-0_bidir-ring_cw" opcode: "all-gather" instruction_id: 5 - bytes_out: 10 + bytes_out: 20 communication_groups { group_ids: 0 group_ids: 1 + group_ids: 2 + group_ids: 3 } inner_subroutines { - name: "all-gather_dim-conc_bidir-ring_cw_unidir-ring" - subroutine_root_id: 6 + name: "all-gather_stage-0_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 8 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-gather_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 6 bytes_in: 5 bytes_out: 5 communication_groups { group_ids: 0 + group_ids: 2 + } + } + instructions { + name: "all-gather_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 7 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 6 + } + instructions { + name: "all-gather_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 8 + bytes_in: 5 + bytes_out: 5 + communication_groups { group_ids: 0 + group_ids: 2 } + operand_ids: 7 } } } instructions { - name: "all-gather_dim-conc_bidir-ring_ccw" + name: "all-gather_stage-0_dim-0_bidir-ring_ccw" opcode: "all-gather" - instruction_id: 7 - bytes_out: 10 + instruction_id: 9 + bytes_out: 20 communication_groups { + group_ids: 3 + group_ids: 2 group_ids: 1 group_ids: 0 } inner_subroutines { - name: "all-gather_dim-conc_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 8 + name: "all-gather_stage-0_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 12 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-gather_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 8 + instruction_id: 10 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 2 + group_ids: 0 + } + } + instructions { + name: "all-gather_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 11 bytes_in: 5 bytes_out: 5 communication_groups { + group_ids: 2 group_ids: 0 + } + operand_ids: 10 + } + instructions { + name: "all-gather_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 12 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 2 group_ids: 0 } + operand_ids: 11 } } } instructions { - name: "all-gather_dim-conc_bidir-ring_root_1" + name: "all-gather_stage-0_dim-0_bidir-ring_root_1" opcode: "null" - instruction_id: 9 + instruction_id: 13 operand_ids: 5 - operand_ids: 7 + operand_ids: 9 } } } instructions { - name: "all-gather_dim-0" + name: "all-gather_stage-0_dim-1" opcode: "all-gather" - instruction_id: 10 - bytes_out: 20 + instruction_id: 14 + bytes_out: 40 communication_groups { + group_ids: 0 group_ids: 1 - group_ids: 3 + group_ids: 4 + group_ids: 5 } - operand_ids: 4 inner_subroutines { - name: "all-gather_dim-0_bidir-ring" - subroutine_root_id: 15 + name: "all-gather_stage-0_dim-1_bidir-ring" + subroutine_root_id: 23 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-0_bidir-ring_cw" + name: "all-gather_stage-0_dim-1_bidir-ring_cw" opcode: "all-gather" - instruction_id: 11 - bytes_out: 10 + instruction_id: 15 + bytes_out: 20 communication_groups { + group_ids: 0 group_ids: 1 - group_ids: 3 + group_ids: 4 + group_ids: 5 } inner_subroutines { - name: "all-gather_dim-0_bidir-ring_cw_unidir-ring" - subroutine_root_id: 12 + name: "all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 18 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 12 + instruction_id: 16 bytes_in: 5 bytes_out: 5 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 4 + } + } + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 17 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 0 + group_ids: 4 + } + operand_ids: 16 + } + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 18 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 0 + group_ids: 4 } + operand_ids: 17 } } } instructions { - name: "all-gather_dim-0_bidir-ring_ccw" + name: "all-gather_stage-0_dim-1_bidir-ring_ccw" opcode: "all-gather" - instruction_id: 13 - bytes_out: 10 + instruction_id: 19 + bytes_out: 20 communication_groups { - group_ids: 3 + group_ids: 5 + group_ids: 4 group_ids: 1 + group_ids: 0 } inner_subroutines { - name: "all-gather_dim-0_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 14 + name: "all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 22 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 14 + instruction_id: 20 bytes_in: 5 bytes_out: 5 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 4 + group_ids: 0 + } + } + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 21 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 4 + group_ids: 0 + } + operand_ids: 20 + } + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 22 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 4 + group_ids: 0 } + operand_ids: 21 } } } instructions { - name: "all-gather_dim-0_bidir-ring_root_1" + name: "all-gather_stage-0_dim-1_bidir-ring_root_1" opcode: "null" - instruction_id: 15 - operand_ids: 11 - operand_ids: 13 + instruction_id: 23 + operand_ids: 15 + operand_ids: 19 } } } instructions { - name: "all-gather_dim-1" + name: "all-gather_stage-0_root" + opcode: "null" + instruction_id: 24 + operand_ids: 4 + operand_ids: 14 + } + instructions { + name: "all-gather_stage-1_dim-0" opcode: "all-gather" - instruction_id: 16 - bytes_out: 20 + instruction_id: 25 + bytes_out: 80 communication_groups { + group_ids: 0 group_ids: 1 - group_ids: 5 + group_ids: 2 + group_ids: 3 } - operand_ids: 10 + operand_ids: 24 inner_subroutines { - name: "all-gather_dim-1_bidir-ring" - subroutine_root_id: 21 + name: "all-gather_stage-1_dim-0_bidir-ring" + subroutine_root_id: 34 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_bidir-ring_cw" + name: "all-gather_stage-1_dim-0_bidir-ring_cw" opcode: "all-gather" - instruction_id: 17 - bytes_out: 10 + instruction_id: 26 + bytes_out: 40 communication_groups { + group_ids: 0 group_ids: 1 - group_ids: 5 + group_ids: 2 + group_ids: 3 } inner_subroutines { - name: "all-gather_dim-1_bidir-ring_cw_unidir-ring" - subroutine_root_id: 18 + name: "all-gather_stage-1_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 29 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-gather_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 18 - bytes_in: 5 - bytes_out: 5 + instruction_id: 27 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 2 + } + } + instructions { + name: "all-gather_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 28 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 27 + } + instructions { + name: "all-gather_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 29 + bytes_in: 10 + bytes_out: 10 communication_groups { - group_ids: 5 - group_ids: 5 + group_ids: 0 + group_ids: 2 } + operand_ids: 28 } } } instructions { - name: "all-gather_dim-1_bidir-ring_ccw" + name: "all-gather_stage-1_dim-0_bidir-ring_ccw" opcode: "all-gather" - instruction_id: 19 - bytes_out: 10 + instruction_id: 30 + bytes_out: 40 communication_groups { - group_ids: 5 + group_ids: 3 + group_ids: 2 group_ids: 1 + group_ids: 0 } inner_subroutines { - name: "all-gather_dim-1_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 20 + name: "all-gather_stage-1_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 33 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-gather_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 20 - bytes_in: 5 - bytes_out: 5 + instruction_id: 31 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 2 + group_ids: 0 + } + } + instructions { + name: "all-gather_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 32 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 2 + group_ids: 0 + } + operand_ids: 31 + } + instructions { + name: "all-gather_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 33 + bytes_in: 10 + bytes_out: 10 communication_groups { - group_ids: 5 - group_ids: 5 + group_ids: 2 + group_ids: 0 } + operand_ids: 32 } } } instructions { - name: "all-gather_dim-1_bidir-ring_root_1" + name: "all-gather_stage-1_dim-0_bidir-ring_root_1" opcode: "null" - instruction_id: 21 - operand_ids: 17 - operand_ids: 19 + instruction_id: 34 + operand_ids: 26 + operand_ids: 30 } } } -} - )proto"; - google::protobuf::TextFormat::ParseFromString(allgather_str, - &allgather_proto); - EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( - allgather->ToProto().value(), allgather_proto)); -} - -// Tests expanding 1D-Torus all-gather with barrier -TEST(Torus2dAllGather, WithBarrier) { - auto graph = absl::make_unique("test_graph", 2); - auto sub = absl::make_unique( - "test_subroutine", graph.get()); - auto sub_ptr = sub.get(); - sub_ptr->SetId(3); - graph->SetEntrySubroutine(std::move(sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); - instr_1->SetOps(4); - - ASSERT_OK_AND_ASSIGN(auto allgather, - paragraph::Instruction::Create( - paragraph::Opcode::kAllGather, "all-gather", sub_ptr)); - allgather->SetBytesOut(80); - paragraph::CommunicationGroup allgather_group = {0, 1, 2, 3, 4, 5, 6, 7}; - allgather->AppendCommunicationGroup(allgather_group); - - ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); + instructions { + name: "all-gather_stage-1_dim-1" + opcode: "all-gather" + instruction_id: 35 + bytes_out: 80 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 4 + group_ids: 5 + } + operand_ids: 24 + inner_subroutines { + name: "all-gather_stage-1_dim-1_bidir-ring" + subroutine_root_id: 44 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_cw" + opcode: "all-gather" + instruction_id: 36 + bytes_out: 40 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 4 + group_ids: 5 + } + inner_subroutines { + name: "all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 39 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 37 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 4 + } + } + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 38 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 4 + } + operand_ids: 37 + } + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 39 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 4 + } + operand_ids: 38 + } + } + } + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_ccw" + opcode: "all-gather" + instruction_id: 40 + bytes_out: 40 + communication_groups { + group_ids: 5 + group_ids: 4 + group_ids: 1 + group_ids: 0 + } + inner_subroutines { + name: "all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 43 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 41 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 4 + group_ids: 0 + } + } + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 42 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 4 + group_ids: 0 + } + operand_ids: 41 + } + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 43 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 4 + group_ids: 0 + } + operand_ids: 42 + } + } + } + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_root_1" + opcode: "null" + instruction_id: 44 + operand_ids: 36 + operand_ids: 40 + } + } + } + instructions { + name: "all-gather_stage-1_root" + opcode: "null" + instruction_id: 45 + operand_ids: 25 + operand_ids: 35 + } +} + )proto"; + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT + +// Tests expanding 2D-Torus all-gather +TEST(Torus2dAllGather, NoBarrier) { + auto graph = absl::make_unique("test_graph", 1); + auto sub = absl::make_unique( + "test_subroutine", graph.get()); + auto sub_ptr = sub.get(); + graph->SetEntrySubroutine(std::move(sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); + instr_1->SetOps(4); + + ASSERT_OK_AND_ASSIGN(auto allgather, + paragraph::Instruction::Create( + paragraph::Opcode::kAllGather, "all-gather", sub_ptr)); + allgather->SetBytesOut(80); + paragraph::CommunicationGroup allgather_group = {0, 1, 2, 3, 4, 5, 6, 7}; + allgather->AppendCommunicationGroup(allgather_group); + + ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); instr_3->SetOps(4); nlohmann::json config = R"( @@ -354,9 +623,7 @@ TEST(Torus2dAllGather, WithBarrier) { "algorithm": "torus-2d", "concentration": 2, "dimension_widths": [2, 2], - "barrier": { - "algorithm": "centralized" - } + "integrated_local_exchange": true } } )"_json; @@ -365,8 +632,14 @@ TEST(Torus2dAllGather, WithBarrier) { paragraph::TranslatorType::kCollective, config)); EXPECT_OK(translators["all-gather"]->Translate(allgather)); - paragraph::InstructionProto allgather_proto; - std::string allgather_str = + paragraph::InstructionProto allgather_proto = no_barrier_test_proto(); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + allgather->ToProto().value(), allgather_proto)); +} + +paragraph::InstructionProto with_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "all-gather" opcode: "all-gather" @@ -384,370 +657,613 @@ communication_groups { } inner_subroutines { name: "all-gather_torus-2d" - subroutine_root_id: 23 + subroutine_root_id: 44 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc" + name: "all-gather_stage-0_dim-0" opcode: "all-gather" instruction_id: 4 bytes_out: 20 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "all-gather_dim-conc_bidir-ring" - subroutine_root_id: 13 + name: "all-gather_stage-0_dim-0_bidir-ring" + subroutine_root_id: 12 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc_bidir-ring_barrier" + name: "all-gather_stage-0_dim-0_bidir-ring_barrier" opcode: "barrier" instruction_id: 5 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "all-gather_dim-conc_bidir-ring_barrier_centralized" - subroutine_root_id: 8 + name: "all-gather_stage-0_dim-0_bidir-ring_barrier_centralized" + subroutine_root_id: 7 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc_bidir-ring_barrier_centralized_coordinator_recv_from_3" - opcode: "recv" + name: "all-gather_stage-0_dim-0_bidir-ring_barrier_centralized_send_to_0" + opcode: "send" instruction_id: 6 communication_groups { - group_ids: 3 + group_ids: 0 } } instructions { - name: "all-gather_dim-conc_bidir-ring_barrier_centralized_coordinator_send_to_3" - opcode: "send" + name: "all-gather_stage-0_dim-0_bidir-ring_barrier_centralized_recv_from_0" + opcode: "recv" instruction_id: 7 communication_groups { - group_ids: 3 + group_ids: 0 } operand_ids: 6 } + } + } + instructions { + name: "all-gather_stage-0_dim-0_bidir-ring_cw" + opcode: "all-gather" + instruction_id: 8 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 5 + inner_subroutines { + name: "all-gather_stage-0_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 9 + execution_probability: 1 + execution_count: 1 instructions { - name: "all-gather_dim-conc_bidir-ring_barrier_centralized_root_2" - opcode: "null" - instruction_id: 8 - operand_ids: 7 + name: "all-gather_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 9 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 0 + group_ids: 0 + } } } } instructions { - name: "all-gather_dim-conc_bidir-ring_cw" + name: "all-gather_stage-0_dim-0_bidir-ring_ccw" opcode: "all-gather" - instruction_id: 9 + instruction_id: 10 bytes_out: 10 communication_groups { group_ids: 2 - group_ids: 3 + group_ids: 0 } operand_ids: 5 inner_subroutines { - name: "all-gather_dim-conc_bidir-ring_cw_unidir-ring" - subroutine_root_id: 10 + name: "all-gather_stage-0_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 11 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-gather_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 10 + instruction_id: 11 bytes_in: 5 bytes_out: 5 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 0 } } } } instructions { - name: "all-gather_dim-conc_bidir-ring_ccw" + name: "all-gather_stage-0_dim-0_bidir-ring_root_2" + opcode: "null" + instruction_id: 12 + operand_ids: 8 + operand_ids: 10 + } + } + } + instructions { + name: "all-gather_stage-0_dim-1" + opcode: "all-gather" + instruction_id: 13 + bytes_out: 20 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "all-gather_stage-0_dim-1_bidir-ring" + subroutine_root_id: 22 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_barrier" + opcode: "barrier" + instruction_id: 14 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "all-gather_stage-0_dim-1_bidir-ring_barrier_centralized" + subroutine_root_id: 17 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_barrier_centralized_coordinator_recv_from_6" + opcode: "recv" + instruction_id: 15 + communication_groups { + group_ids: 6 + } + } + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_barrier_centralized_coordinator_send_to_6" + opcode: "send" + instruction_id: 16 + communication_groups { + group_ids: 6 + } + operand_ids: 15 + } + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_barrier_centralized_root_2" + opcode: "null" + instruction_id: 17 + operand_ids: 16 + } + } + } + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_cw" opcode: "all-gather" - instruction_id: 11 + instruction_id: 18 bytes_out: 10 communication_groups { - group_ids: 3 group_ids: 2 + group_ids: 6 + } + operand_ids: 14 + inner_subroutines { + name: "all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 19 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 19 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 6 + group_ids: 6 + } + } + } + } + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_ccw" + opcode: "all-gather" + instruction_id: 20 + bytes_out: 10 + communication_groups { + group_ids: 6 + group_ids: 2 + } + operand_ids: 14 + inner_subroutines { + name: "all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 21 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 21 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 6 + group_ids: 6 + } + } + } + } + instructions { + name: "all-gather_stage-0_dim-1_bidir-ring_root_2" + opcode: "null" + instruction_id: 22 + operand_ids: 18 + operand_ids: 20 + } + } + } + instructions { + name: "all-gather_stage-0_root" + opcode: "null" + instruction_id: 23 + operand_ids: 4 + operand_ids: 13 + } + instructions { + name: "all-gather_stage-1_dim-0" + opcode: "all-gather" + instruction_id: 24 + bytes_out: 40 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 23 + inner_subroutines { + name: "all-gather_stage-1_dim-0_bidir-ring" + subroutine_root_id: 32 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-1_dim-0_bidir-ring_barrier" + opcode: "barrier" + instruction_id: 25 + communication_groups { + group_ids: 0 + group_ids: 2 + } + inner_subroutines { + name: "all-gather_stage-1_dim-0_bidir-ring_barrier_centralized" + subroutine_root_id: 27 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-1_dim-0_bidir-ring_barrier_centralized_send_to_0" + opcode: "send" + instruction_id: 26 + communication_groups { + group_ids: 0 + } + } + instructions { + name: "all-gather_stage-1_dim-0_bidir-ring_barrier_centralized_recv_from_0" + opcode: "recv" + instruction_id: 27 + communication_groups { + group_ids: 0 + } + operand_ids: 26 + } + } + } + instructions { + name: "all-gather_stage-1_dim-0_bidir-ring_cw" + opcode: "all-gather" + instruction_id: 28 + bytes_out: 20 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 25 + inner_subroutines { + name: "all-gather_stage-1_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 29 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 29 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + } + } + instructions { + name: "all-gather_stage-1_dim-0_bidir-ring_ccw" + opcode: "all-gather" + instruction_id: 30 + bytes_out: 20 + communication_groups { + group_ids: 2 + group_ids: 0 } - operand_ids: 5 + operand_ids: 25 inner_subroutines { - name: "all-gather_dim-conc_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 12 + name: "all-gather_stage-1_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 31 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-conc_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-gather_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 12 - bytes_in: 5 - bytes_out: 5 + instruction_id: 31 + bytes_in: 10 + bytes_out: 10 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 0 } } } } instructions { - name: "all-gather_dim-conc_bidir-ring_root_2" + name: "all-gather_stage-1_dim-0_bidir-ring_root_2" opcode: "null" - instruction_id: 13 - operand_ids: 9 - operand_ids: 11 + instruction_id: 32 + operand_ids: 28 + operand_ids: 30 } } } instructions { - name: "all-gather_dim-0" + name: "all-gather_stage-1_dim-1" opcode: "all-gather" - instruction_id: 14 - bytes_out: 20 + instruction_id: 33 + bytes_out: 40 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } - operand_ids: 4 + operand_ids: 23 inner_subroutines { - name: "all-gather_dim-0_bidir-ring" - subroutine_root_id: 22 + name: "all-gather_stage-1_dim-1_bidir-ring" + subroutine_root_id: 42 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-0_bidir-ring_barrier" + name: "all-gather_stage-1_dim-1_bidir-ring_barrier" opcode: "barrier" - instruction_id: 15 + instruction_id: 34 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } inner_subroutines { - name: "all-gather_dim-0_bidir-ring_barrier_centralized" - subroutine_root_id: 17 + name: "all-gather_stage-1_dim-1_bidir-ring_barrier_centralized" + subroutine_root_id: 37 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-0_bidir-ring_barrier_centralized_send_to_0" - opcode: "send" - instruction_id: 16 + name: "all-gather_stage-1_dim-1_bidir-ring_barrier_centralized_coordinator_recv_from_6" + opcode: "recv" + instruction_id: 35 communication_groups { - group_ids: 0 + group_ids: 6 } } instructions { - name: "all-gather_dim-0_bidir-ring_barrier_centralized_recv_from_0" - opcode: "recv" - instruction_id: 17 + name: "all-gather_stage-1_dim-1_bidir-ring_barrier_centralized_coordinator_send_to_6" + opcode: "send" + instruction_id: 36 communication_groups { - group_ids: 0 + group_ids: 6 } - operand_ids: 16 + operand_ids: 35 + } + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_barrier_centralized_root_2" + opcode: "null" + instruction_id: 37 + operand_ids: 36 } } } instructions { - name: "all-gather_dim-0_bidir-ring_cw" + name: "all-gather_stage-1_dim-1_bidir-ring_cw" opcode: "all-gather" - instruction_id: 18 - bytes_out: 10 + instruction_id: 38 + bytes_out: 20 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } - operand_ids: 15 + operand_ids: 34 inner_subroutines { - name: "all-gather_dim-0_bidir-ring_cw_unidir-ring" - subroutine_root_id: 19 + name: "all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 39 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 19 - bytes_in: 5 - bytes_out: 5 + instruction_id: 39 + bytes_in: 10 + bytes_out: 10 communication_groups { - group_ids: 0 - group_ids: 0 + group_ids: 6 + group_ids: 6 } } } } instructions { - name: "all-gather_dim-0_bidir-ring_ccw" + name: "all-gather_stage-1_dim-1_bidir-ring_ccw" opcode: "all-gather" - instruction_id: 20 - bytes_out: 10 + instruction_id: 40 + bytes_out: 20 communication_groups { + group_ids: 6 group_ids: 2 - group_ids: 0 } - operand_ids: 15 + operand_ids: 34 inner_subroutines { - name: "all-gather_dim-0_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 21 + name: "all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 41 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 21 - bytes_in: 5 - bytes_out: 5 + instruction_id: 41 + bytes_in: 10 + bytes_out: 10 communication_groups { - group_ids: 0 - group_ids: 0 + group_ids: 6 + group_ids: 6 } } } } instructions { - name: "all-gather_dim-0_bidir-ring_root_2" + name: "all-gather_stage-1_dim-1_bidir-ring_root_2" opcode: "null" - instruction_id: 22 - operand_ids: 18 - operand_ids: 20 + instruction_id: 42 + operand_ids: 38 + operand_ids: 40 } } } instructions { - name: "all-gather_dim-1" + name: "all-gather_stage-1_root" + opcode: "null" + instruction_id: 43 + operand_ids: 24 + operand_ids: 33 + } + instructions { + name: "all-gather_conc" opcode: "all-gather" - instruction_id: 23 - bytes_out: 20 + instruction_id: 44 + bytes_out: 80 communication_groups { group_ids: 2 - group_ids: 6 + group_ids: 3 } - operand_ids: 14 + operand_ids: 43 inner_subroutines { - name: "all-gather_dim-1_bidir-ring" - subroutine_root_id: 32 + name: "all-gather_conc_bidir-ring" + subroutine_root_id: 53 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_bidir-ring_barrier" + name: "all-gather_conc_bidir-ring_barrier" opcode: "barrier" - instruction_id: 24 + instruction_id: 45 communication_groups { group_ids: 2 - group_ids: 6 + group_ids: 3 } inner_subroutines { - name: "all-gather_dim-1_bidir-ring_barrier_centralized" - subroutine_root_id: 27 + name: "all-gather_conc_bidir-ring_barrier_centralized" + subroutine_root_id: 48 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_bidir-ring_barrier_centralized_coordinator_recv_from_6" + name: "all-gather_conc_bidir-ring_barrier_centralized_coordinator_recv_from_3" opcode: "recv" - instruction_id: 25 + instruction_id: 46 communication_groups { - group_ids: 6 + group_ids: 3 } } instructions { - name: "all-gather_dim-1_bidir-ring_barrier_centralized_coordinator_send_to_6" + name: "all-gather_conc_bidir-ring_barrier_centralized_coordinator_send_to_3" opcode: "send" - instruction_id: 26 + instruction_id: 47 communication_groups { - group_ids: 6 + group_ids: 3 } - operand_ids: 25 + operand_ids: 46 } instructions { - name: "all-gather_dim-1_bidir-ring_barrier_centralized_root_2" + name: "all-gather_conc_bidir-ring_barrier_centralized_root_2" opcode: "null" - instruction_id: 27 - operand_ids: 26 + instruction_id: 48 + operand_ids: 47 } } } instructions { - name: "all-gather_dim-1_bidir-ring_cw" + name: "all-gather_conc_bidir-ring_cw" opcode: "all-gather" - instruction_id: 28 - bytes_out: 10 + instruction_id: 49 + bytes_out: 40 communication_groups { group_ids: 2 - group_ids: 6 + group_ids: 3 } - operand_ids: 24 + operand_ids: 45 inner_subroutines { - name: "all-gather_dim-1_bidir-ring_cw_unidir-ring" - subroutine_root_id: 29 + name: "all-gather_conc_bidir-ring_cw_unidir-ring" + subroutine_root_id: 50 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-gather_conc_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 29 - bytes_in: 5 - bytes_out: 5 + instruction_id: 50 + bytes_in: 20 + bytes_out: 20 communication_groups { - group_ids: 6 - group_ids: 6 + group_ids: 3 + group_ids: 3 } } } } instructions { - name: "all-gather_dim-1_bidir-ring_ccw" + name: "all-gather_conc_bidir-ring_ccw" opcode: "all-gather" - instruction_id: 30 - bytes_out: 10 + instruction_id: 51 + bytes_out: 40 communication_groups { - group_ids: 6 + group_ids: 3 group_ids: 2 } - operand_ids: 24 + operand_ids: 45 inner_subroutines { - name: "all-gather_dim-1_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 31 + name: "all-gather_conc_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 52 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-gather_conc_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 31 - bytes_in: 5 - bytes_out: 5 + instruction_id: 52 + bytes_in: 20 + bytes_out: 20 communication_groups { - group_ids: 6 - group_ids: 6 + group_ids: 3 + group_ids: 3 } } } } instructions { - name: "all-gather_dim-1_bidir-ring_root_2" + name: "all-gather_conc_bidir-ring_root_2" opcode: "null" - instruction_id: 32 - operand_ids: 28 - operand_ids: 30 + instruction_id: 53 + operand_ids: 49 + operand_ids: 51 } } } } - )proto"; - google::protobuf::TextFormat::ParseFromString(allgather_str, - &allgather_proto); - EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( - allgather->ToProto().value(), allgather_proto)); -} + )proto"; + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT -// Tests expanding 1D-Torus all-gather -TEST(Torus2dAllGather, InconsecutiveProcessors) { +// Tests expanding 1D-Torus all-gather with barrier +TEST(Torus2dAllGather, WithBarrier) { auto graph = absl::make_unique("test_graph", 2); auto sub = absl::make_unique( "test_subroutine", graph.get()); auto sub_ptr = sub.get(); + sub_ptr->SetId(3); graph->SetEntrySubroutine(std::move(sub)); ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( @@ -757,8 +1273,8 @@ TEST(Torus2dAllGather, InconsecutiveProcessors) { ASSERT_OK_AND_ASSIGN(auto allgather, paragraph::Instruction::Create( paragraph::Opcode::kAllGather, "all-gather", sub_ptr)); - allgather->SetBytesOut(48); - paragraph::CommunicationGroup allgather_group = {0, 2, 4}; + allgather->SetBytesOut(80); + paragraph::CommunicationGroup allgather_group = {0, 1, 2, 3, 4, 5, 6, 7}; allgather->AppendCommunicationGroup(allgather_group); ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( @@ -769,7 +1285,11 @@ TEST(Torus2dAllGather, InconsecutiveProcessors) { { "all-gather": { "algorithm": "torus-2d", - "dimension_widths": [2, 3] + "concentration": 2, + "dimension_widths": [2, 2], + "barrier": { + "algorithm": "centralized" + } } } )"_json; @@ -778,8 +1298,14 @@ TEST(Torus2dAllGather, InconsecutiveProcessors) { paragraph::TranslatorType::kCollective, config)); EXPECT_OK(translators["all-gather"]->Translate(allgather)); - paragraph::InstructionProto allgather_proto; - std::string allgather_str = + paragraph::InstructionProto allgather_proto = with_barrier_test_proto(); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + allgather->ToProto().value(), allgather_proto)); +} + +paragraph::InstructionProto inconsecutive_proc_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "all-gather" opcode: "all-gather" @@ -792,11 +1318,11 @@ communication_groups { } inner_subroutines { name: "all-gather_torus-2d" - subroutine_root_id: 4 + subroutine_root_id: 21 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1" + name: "all-gather_stage-0_dim-1" opcode: "all-gather" instruction_id: 4 bytes_out: 48 @@ -806,12 +1332,12 @@ inner_subroutines { group_ids: 4 } inner_subroutines { - name: "all-gather_dim-1_bidir-ring" + name: "all-gather_stage-0_dim-1_bidir-ring" subroutine_root_id: 11 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_bidir-ring_cw" + name: "all-gather_stage-0_dim-1_bidir-ring_cw" opcode: "all-gather" instruction_id: 5 bytes_out: 24 @@ -821,12 +1347,12 @@ inner_subroutines { group_ids: 4 } inner_subroutines { - name: "all-gather_dim-1_bidir-ring_cw_unidir-ring" + name: "all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring" subroutine_root_id: 7 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 6 bytes_in: 8 @@ -837,7 +1363,7 @@ inner_subroutines { } } instructions { - name: "all-gather_dim-1_bidir-ring_cw_unidir-ring_sendrecv_2" + name: "all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_2" opcode: "sendrecv" instruction_id: 7 bytes_in: 8 @@ -851,7 +1377,7 @@ inner_subroutines { } } instructions { - name: "all-gather_dim-1_bidir-ring_ccw" + name: "all-gather_stage-0_dim-1_bidir-ring_ccw" opcode: "all-gather" instruction_id: 8 bytes_out: 24 @@ -861,12 +1387,12 @@ inner_subroutines { group_ids: 0 } inner_subroutines { - name: "all-gather_dim-1_bidir-ring_ccw_unidir-ring" + name: "all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring" subroutine_root_id: 10 execution_probability: 1 execution_count: 1 instructions { - name: "all-gather_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 9 bytes_in: 8 @@ -877,7 +1403,7 @@ inner_subroutines { } } instructions { - name: "all-gather_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_2" + name: "all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_2" opcode: "sendrecv" instruction_id: 10 bytes_in: 8 @@ -891,7 +1417,7 @@ inner_subroutines { } } instructions { - name: "all-gather_dim-1_bidir-ring_root_2" + name: "all-gather_stage-0_dim-1_bidir-ring_root_2" opcode: "null" instruction_id: 11 operand_ids: 5 @@ -899,10 +1425,168 @@ inner_subroutines { } } } + instructions { + name: "all-gather_stage-0_root" + opcode: "null" + instruction_id: 12 + operand_ids: 4 + } + instructions { + name: "all-gather_stage-1_dim-1" + opcode: "all-gather" + instruction_id: 13 + bytes_out: 144 + communication_groups { + group_ids: 0 + group_ids: 2 + group_ids: 4 + } + operand_ids: 12 + inner_subroutines { + name: "all-gather_stage-1_dim-1_bidir-ring" + subroutine_root_id: 20 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_cw" + opcode: "all-gather" + instruction_id: 14 + bytes_out: 72 + communication_groups { + group_ids: 0 + group_ids: 2 + group_ids: 4 + } + inner_subroutines { + name: "all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 16 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 15 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 0 + group_ids: 4 + } + } + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 16 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 0 + group_ids: 4 + } + operand_ids: 15 + } + } + } + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_ccw" + opcode: "all-gather" + instruction_id: 17 + bytes_out: 72 + communication_groups { + group_ids: 4 + group_ids: 2 + group_ids: 0 + } + inner_subroutines { + name: "all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 19 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 18 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 4 + group_ids: 0 + } + } + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 19 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 4 + group_ids: 0 + } + operand_ids: 18 + } + } + } + instructions { + name: "all-gather_stage-1_dim-1_bidir-ring_root_2" + opcode: "null" + instruction_id: 20 + operand_ids: 14 + operand_ids: 17 + } + } + } + instructions { + name: "all-gather_stage-1_root" + opcode: "null" + instruction_id: 21 + operand_ids: 13 + } } - )proto"; - google::protobuf::TextFormat::ParseFromString(allgather_str, - &allgather_proto); + )proto"; + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT + +// Tests expanding 1D-Torus all-gather +TEST(Torus2dAllGather, InconsecutiveProcessors) { + auto graph = absl::make_unique("test_graph", 2); + auto sub = absl::make_unique( + "test_subroutine", graph.get()); + auto sub_ptr = sub.get(); + graph->SetEntrySubroutine(std::move(sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); + instr_1->SetOps(4); + + ASSERT_OK_AND_ASSIGN(auto allgather, + paragraph::Instruction::Create( + paragraph::Opcode::kAllGather, "all-gather", sub_ptr)); + allgather->SetBytesOut(48); + paragraph::CommunicationGroup allgather_group = {0, 2, 4}; + allgather->AppendCommunicationGroup(allgather_group); + + ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); + instr_3->SetOps(4); + + nlohmann::json config = R"( + { + "all-gather": { + "algorithm": "torus-2d", + "dimension_widths": [2, 3] + } + } + )"_json; + + ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( + paragraph::TranslatorType::kCollective, config)); + EXPECT_OK(translators["all-gather"]->Translate(allgather)); + + paragraph::InstructionProto allgather_proto = + inconsecutive_proc_test_proto(); EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( allgather->ToProto().value(), allgather_proto)); } diff --git a/paragraph/translation/allreduce/mesh_2d_allreduce_translator_test.cc b/paragraph/translation/allreduce/mesh_2d_allreduce_translator_test.cc index b7f8568..8c4ad67 100644 --- a/paragraph/translation/allreduce/mesh_2d_allreduce_translator_test.cc +++ b/paragraph/translation/allreduce/mesh_2d_allreduce_translator_test.cc @@ -24,61 +24,9 @@ #include "paragraph/shim/test_macros.h" #include "paragraph/translation/translation_map.h" -// Tests expanding 2D-Mesh all-reduce -TEST(Mesh2dAllReduce, NoBarrier) { - auto graph = absl::make_unique("test_graph", 2); - auto sub = absl::make_unique( - "test_subroutine", graph.get()); - auto sub_ptr = sub.get(); - sub_ptr->SetId(3); - graph->SetEntrySubroutine(std::move(sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); - instr_1->SetOps(4); - - ASSERT_OK_AND_ASSIGN(auto allreduce, paragraph::Instruction::Create( - paragraph::Opcode::kAllReduce, "all-reduce", sub_ptr)); - allreduce->SetBytesOut(48); - paragraph::CommunicationGroup allreduce_group = {0, 1, 2, 3, 4, 5, 6, 7}; - allreduce->AppendCommunicationGroup(allreduce_group); - - auto reduction_sub = absl::make_unique( - "reduction_subroutine", graph.get()); - auto reduction_ptr = reduction_sub.get(); - ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "op1", reduction_ptr)); - op1->SetBytesOut(48); - ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "op2", reduction_ptr)); - op2->SetBytesOut(48); - ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); - sum_op->SetOps(96); - sum_op->AddOperand(op1); - sum_op->AddOperand(op2); - allreduce->AppendInnerSubroutine(std::move(reduction_sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); - instr_3->SetOps(4); - - nlohmann::json config = R"( - { - "all-reduce": { - "algorithm": "mesh-2d", - "concentration": 2, - "dimension_widths": [2, 2] - } - } - )"_json; - - ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( - paragraph::TranslatorType::kCollective, config)); - EXPECT_OK(translators["all-reduce"]->Translate(allreduce)); - - paragraph::InstructionProto allreduce_proto; - std::string allreduce_str = +paragraph::InstructionProto Mesh2dAllReduce_no_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "all-reduce" opcode: "all-reduce" @@ -96,7 +44,7 @@ communication_groups { } inner_subroutines { name: "all-reduce_mesh-2d" - subroutine_root_id: 26 + subroutine_root_id: 40 execution_probability: 1 execution_count: 1 instructions { @@ -116,58 +64,58 @@ inner_subroutines { } inner_subroutines { name: "all-reduce_mesh-2d_reduce-scatter_mesh-2d" - subroutine_root_id: 20 + subroutine_root_id: 34 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0" opcode: "reduce-scatter" instruction_id: 8 bytes_out: 12 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc_mesh-1d" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0_mesh-1d" subroutine_root_id: 10 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc_mesh-1d_cw_sendrecv_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" instruction_id: 9 bytes_in: 6 bytes_out: 6 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 0 } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc_mesh-1d_cw_reduction_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0_mesh-1d_ccw_reduction_0" opcode: "call" instruction_id: 10 operand_ids: 9 inner_subroutines { - name: "reduction_subroutine_cw_phase_0" + name: "reduction_subroutine_ccw_phase_0" subroutine_root_id: 13 execution_probability: 1 execution_count: 1 instructions { - name: "op1_cw_phase_0" + name: "op1_ccw_phase_0" opcode: "delay" instruction_id: 11 bytes_out: 6 } instructions { - name: "op2_cw_phase_0" + name: "op2_ccw_phase_0" opcode: "delay" instruction_id: 12 bytes_out: 6 } instructions { - name: "sum_cw_phase_0" + name: "sum_ccw_phase_0" opcode: "delay" instruction_id: 13 ops: 12 @@ -179,55 +127,54 @@ inner_subroutines { } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1" opcode: "reduce-scatter" instruction_id: 14 bytes_out: 12 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } - operand_ids: 8 inner_subroutines { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0_mesh-1d" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1_mesh-1d" subroutine_root_id: 16 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0_mesh-1d_ccw_sendrecv_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" instruction_id: 15 bytes_in: 6 bytes_out: 6 communication_groups { - group_ids: 0 - group_ids: 0 + group_ids: 6 + group_ids: 6 } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0_mesh-1d_ccw_reduction_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1_mesh-1d_cw_reduction_0" opcode: "call" instruction_id: 16 operand_ids: 15 inner_subroutines { - name: "reduction_subroutine_ccw_phase_0" + name: "reduction_subroutine_cw_phase_0" subroutine_root_id: 19 execution_probability: 1 execution_count: 1 instructions { - name: "op1_ccw_phase_0" + name: "op1_cw_phase_0" opcode: "delay" instruction_id: 17 bytes_out: 6 } instructions { - name: "op2_ccw_phase_0" + name: "op2_cw_phase_0" opcode: "delay" instruction_id: 18 bytes_out: 6 } instructions { - name: "sum_ccw_phase_0" + name: "sum_cw_phase_0" opcode: "delay" instruction_id: 19 ops: 12 @@ -239,60 +186,194 @@ inner_subroutines { } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1" - opcode: "reduce-scatter" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_root" + opcode: "null" instruction_id: 20 - bytes_out: 12 + operand_ids: 8 + operand_ids: 14 + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0" + opcode: "reduce-scatter" + instruction_id: 21 + bytes_out: 24 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 20 + inner_subroutines { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0_mesh-1d" + subroutine_root_id: 23 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 22 + bytes_in: 12 + bytes_out: 12 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0_mesh-1d_ccw_reduction_0" + opcode: "call" + instruction_id: 23 + operand_ids: 22 + inner_subroutines { + name: "reduction_subroutine_ccw_phase_0" + subroutine_root_id: 26 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_ccw_phase_0" + opcode: "delay" + instruction_id: 24 + bytes_out: 12 + } + instructions { + name: "op2_ccw_phase_0" + opcode: "delay" + instruction_id: 25 + bytes_out: 12 + } + instructions { + name: "sum_ccw_phase_0" + opcode: "delay" + instruction_id: 26 + ops: 24 + operand_ids: 24 + operand_ids: 25 + } + } + } + } + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1" + opcode: "reduce-scatter" + instruction_id: 27 + bytes_out: 24 communication_groups { group_ids: 2 group_ids: 6 } - operand_ids: 14 + operand_ids: 20 inner_subroutines { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1_mesh-1d" - subroutine_root_id: 22 + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1_mesh-1d" + subroutine_root_id: 29 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1_mesh-1d_cw_sendrecv_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" - instruction_id: 21 - bytes_in: 6 - bytes_out: 6 + instruction_id: 28 + bytes_in: 12 + bytes_out: 12 communication_groups { group_ids: 6 group_ids: 6 } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1_mesh-1d_cw_reduction_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1_mesh-1d_cw_reduction_0" opcode: "call" - instruction_id: 22 - operand_ids: 21 + instruction_id: 29 + operand_ids: 28 inner_subroutines { name: "reduction_subroutine_cw_phase_0" - subroutine_root_id: 25 + subroutine_root_id: 32 execution_probability: 1 execution_count: 1 instructions { name: "op1_cw_phase_0" opcode: "delay" - instruction_id: 23 - bytes_out: 6 + instruction_id: 30 + bytes_out: 12 } instructions { name: "op2_cw_phase_0" opcode: "delay" - instruction_id: 24 - bytes_out: 6 + instruction_id: 31 + bytes_out: 12 } instructions { name: "sum_cw_phase_0" opcode: "delay" - instruction_id: 25 - ops: 12 - operand_ids: 23 - operand_ids: 24 + instruction_id: 32 + ops: 24 + operand_ids: 30 + operand_ids: 31 + } + } + } + } + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_root" + opcode: "null" + instruction_id: 33 + operand_ids: 21 + operand_ids: 27 + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_conc" + opcode: "reduce-scatter" + instruction_id: 34 + bytes_out: 48 + communication_groups { + group_ids: 2 + group_ids: 3 + } + operand_ids: 33 + inner_subroutines { + name: "all-reduce_mesh-2d_reduce-scatter_conc_mesh-1d" + subroutine_root_id: 36 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_conc_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 35 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 3 + group_ids: 3 + } + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_conc_mesh-1d_cw_reduction_0" + opcode: "call" + instruction_id: 36 + operand_ids: 35 + inner_subroutines { + name: "reduction_subroutine_cw_phase_0" + subroutine_root_id: 39 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_cw_phase_0" + opcode: "delay" + instruction_id: 37 + bytes_out: 24 + } + instructions { + name: "op2_cw_phase_0" + opcode: "delay" + instruction_id: 38 + bytes_out: 24 + } + instructions { + name: "sum_cw_phase_0" + opcode: "delay" + instruction_id: 39 + ops: 48 + operand_ids: 37 + operand_ids: 38 } } } @@ -303,7 +384,7 @@ inner_subroutines { instructions { name: "all-reduce_mesh-2d_all-gather" opcode: "all-gather" - instruction_id: 26 + instruction_id: 40 bytes_out: 48 communication_groups { group_ids: 0 @@ -318,57 +399,91 @@ inner_subroutines { operand_ids: 7 inner_subroutines { name: "all-reduce_mesh-2d_all-gather_mesh-2d" - subroutine_root_id: 31 + subroutine_root_id: 51 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_all-gather_dim-conc" + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-0" opcode: "all-gather" - instruction_id: 27 + instruction_id: 41 bytes_out: 12 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "all-reduce_mesh-2d_all-gather_dim-conc_mesh-1d" - subroutine_root_id: 28 + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-0_mesh-1d" + subroutine_root_id: 42 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_all-gather_dim-conc_mesh-1d_cw_sendrecv_0" + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-0_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" - instruction_id: 28 + instruction_id: 42 bytes_in: 6 bytes_out: 6 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 0 } } } } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-0" + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-1" opcode: "all-gather" - instruction_id: 29 + instruction_id: 43 bytes_out: 12 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } - operand_ids: 27 inner_subroutines { - name: "all-reduce_mesh-2d_all-gather_dim-0_mesh-1d" - subroutine_root_id: 30 + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-1_mesh-1d" + subroutine_root_id: 44 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_all-gather_dim-0_mesh-1d_ccw_sendrecv_0" + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" - instruction_id: 30 + instruction_id: 44 bytes_in: 6 bytes_out: 6 + communication_groups { + group_ids: 6 + group_ids: 6 + } + } + } + } + instructions { + name: "all-reduce_mesh-2d_all-gather_stage-0_root" + opcode: "null" + instruction_id: 45 + operand_ids: 41 + operand_ids: 43 + } + instructions { + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-0" + opcode: "all-gather" + instruction_id: 46 + bytes_out: 24 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 45 + inner_subroutines { + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-0_mesh-1d" + subroutine_root_id: 47 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-0_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 47 + bytes_in: 12 + bytes_out: 12 communication_groups { group_ids: 0 group_ids: 0 @@ -377,26 +492,26 @@ inner_subroutines { } } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-1" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-1" opcode: "all-gather" - instruction_id: 31 - bytes_out: 12 + instruction_id: 48 + bytes_out: 24 communication_groups { group_ids: 2 group_ids: 6 } - operand_ids: 29 + operand_ids: 45 inner_subroutines { - name: "all-reduce_mesh-2d_all-gather_dim-1_mesh-1d" - subroutine_root_id: 32 + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-1_mesh-1d" + subroutine_root_id: 49 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_all-gather_dim-1_mesh-1d_cw_sendrecv_0" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" - instruction_id: 32 - bytes_in: 6 - bytes_out: 6 + instruction_id: 49 + bytes_in: 12 + bytes_out: 12 communication_groups { group_ids: 6 group_ids: 6 @@ -404,18 +519,52 @@ inner_subroutines { } } } + instructions { + name: "all-reduce_mesh-2d_all-gather_stage-1_root" + opcode: "null" + instruction_id: 50 + operand_ids: 46 + operand_ids: 48 + } + instructions { + name: "all-reduce_mesh-2d_all-gather_conc" + opcode: "all-gather" + instruction_id: 51 + bytes_out: 48 + communication_groups { + group_ids: 2 + group_ids: 3 + } + operand_ids: 50 + inner_subroutines { + name: "all-reduce_mesh-2d_all-gather_conc_mesh-1d" + subroutine_root_id: 52 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_all-gather_conc_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 52 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 3 + group_ids: 3 + } + } + } + } } } } )proto"; - google::protobuf::TextFormat::ParseFromString(allreduce_str, - &allreduce_proto); - EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( - allreduce->ToProto().value(), allreduce_proto)); -} + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT -// Tests expanding 2D-Mesh all-reduce with barrier -TEST(Mesh2dAllReduce, WithBarrier) { +// Tests expanding 2D-Mesh all-reduce +TEST(Mesh2dAllReduce, NoBarrier) { auto graph = absl::make_unique("test_graph", 2); auto sub = absl::make_unique( "test_subroutine", graph.get()); @@ -458,10 +607,7 @@ TEST(Mesh2dAllReduce, WithBarrier) { "all-reduce": { "algorithm": "mesh-2d", "concentration": 2, - "dimension_widths": [2, 2], - "barrier": { - "algorithm": "centralized" - } + "dimension_widths": [2, 2] } } )"_json; @@ -470,8 +616,15 @@ TEST(Mesh2dAllReduce, WithBarrier) { paragraph::TranslatorType::kCollective, config)); EXPECT_OK(translators["all-reduce"]->Translate(allreduce)); - paragraph::InstructionProto allreduce_proto; - std::string allreduce_str = + paragraph::InstructionProto allreduce_proto = + Mesh2dAllReduce_no_barrier_test_proto(); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + allreduce->ToProto().value(), allreduce_proto)); +} + +paragraph::InstructionProto Mesh2dAllReduce_with_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "all-reduce" opcode: "all-reduce" @@ -489,7 +642,7 @@ communication_groups { } inner_subroutines { name: "all-reduce_mesh-2d" - subroutine_root_id: 37 + subroutine_root_id: 58 execution_probability: 1 execution_count: 1 instructions { @@ -509,190 +662,189 @@ inner_subroutines { } inner_subroutines { name: "all-reduce_mesh-2d_reduce-scatter_mesh-2d" - subroutine_root_id: 27 + subroutine_root_id: 48 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0" opcode: "reduce-scatter" instruction_id: 8 bytes_out: 12 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc_mesh-1d" - subroutine_root_id: 14 + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0_mesh-1d" + subroutine_root_id: 13 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc_unidir-ring_barrier" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0_unidir-ring_barrier" opcode: "barrier" instruction_id: 9 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc_unidir-ring_barrier_centralized" - subroutine_root_id: 12 + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0_unidir-ring_barrier_centralized" + subroutine_root_id: 11 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc_unidir-ring_barrier_centralized_coordinator_recv_from_3" - opcode: "recv" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0_unidir-ring_barrier_centralized_send_to_0" + opcode: "send" instruction_id: 10 communication_groups { - group_ids: 3 + group_ids: 0 } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc_unidir-ring_barrier_centralized_coordinator_send_to_3" - opcode: "send" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0_unidir-ring_barrier_centralized_recv_from_0" + opcode: "recv" instruction_id: 11 communication_groups { - group_ids: 3 + group_ids: 0 } operand_ids: 10 } - instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc_unidir-ring_barrier_centralized_root_2" - opcode: "null" - instruction_id: 12 - operand_ids: 11 - } } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc_mesh-1d_cw_sendrecv_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" - instruction_id: 13 + instruction_id: 12 bytes_in: 6 bytes_out: 6 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 0 } operand_ids: 9 } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-conc_mesh-1d_cw_reduction_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-0_mesh-1d_ccw_reduction_0" opcode: "call" - instruction_id: 14 - operand_ids: 13 + instruction_id: 13 + operand_ids: 12 inner_subroutines { - name: "reduction_subroutine_cw_phase_0" - subroutine_root_id: 17 + name: "reduction_subroutine_ccw_phase_0" + subroutine_root_id: 16 execution_probability: 1 execution_count: 1 instructions { - name: "op1_cw_phase_0" + name: "op1_ccw_phase_0" opcode: "delay" - instruction_id: 15 + instruction_id: 14 bytes_out: 6 } instructions { - name: "op2_cw_phase_0" + name: "op2_ccw_phase_0" opcode: "delay" - instruction_id: 16 + instruction_id: 15 bytes_out: 6 } instructions { - name: "sum_cw_phase_0" + name: "sum_ccw_phase_0" opcode: "delay" - instruction_id: 17 + instruction_id: 16 ops: 12 + operand_ids: 14 operand_ids: 15 - operand_ids: 16 } } } } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1" opcode: "reduce-scatter" - instruction_id: 18 + instruction_id: 17 bytes_out: 12 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } - operand_ids: 8 inner_subroutines { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0_mesh-1d" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1_mesh-1d" subroutine_root_id: 23 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0_unidir-ring_barrier" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1_unidir-ring_barrier" opcode: "barrier" - instruction_id: 19 + instruction_id: 18 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } inner_subroutines { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0_unidir-ring_barrier_centralized" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1_unidir-ring_barrier_centralized" subroutine_root_id: 21 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0_unidir-ring_barrier_centralized_send_to_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1_unidir-ring_barrier_centralized_coordinator_recv_from_6" + opcode: "recv" + instruction_id: 19 + communication_groups { + group_ids: 6 + } + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1_unidir-ring_barrier_centralized_coordinator_send_to_6" opcode: "send" instruction_id: 20 communication_groups { - group_ids: 0 + group_ids: 6 } + operand_ids: 19 } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0_unidir-ring_barrier_centralized_recv_from_0" - opcode: "recv" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1_unidir-ring_barrier_centralized_root_2" + opcode: "null" instruction_id: 21 - communication_groups { - group_ids: 0 - } operand_ids: 20 } } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0_mesh-1d_ccw_sendrecv_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" instruction_id: 22 bytes_in: 6 bytes_out: 6 communication_groups { - group_ids: 0 - group_ids: 0 + group_ids: 6 + group_ids: 6 } - operand_ids: 19 + operand_ids: 18 } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-0_mesh-1d_ccw_reduction_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_dim-1_mesh-1d_cw_reduction_0" opcode: "call" instruction_id: 23 operand_ids: 22 inner_subroutines { - name: "reduction_subroutine_ccw_phase_0" + name: "reduction_subroutine_cw_phase_0" subroutine_root_id: 26 execution_probability: 1 execution_count: 1 instructions { - name: "op1_ccw_phase_0" + name: "op1_cw_phase_0" opcode: "delay" instruction_id: 24 bytes_out: 6 } instructions { - name: "op2_ccw_phase_0" + name: "op2_cw_phase_0" opcode: "delay" instruction_id: 25 bytes_out: 6 } instructions { - name: "sum_ccw_phase_0" + name: "sum_cw_phase_0" opcode: "delay" instruction_id: 26 ops: 12 @@ -704,99 +856,305 @@ inner_subroutines { } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1" - opcode: "reduce-scatter" + name: "all-reduce_mesh-2d_reduce-scatter_stage-0_root" + opcode: "null" instruction_id: 27 - bytes_out: 12 + operand_ids: 8 + operand_ids: 17 + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0" + opcode: "reduce-scatter" + instruction_id: 28 + bytes_out: 24 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 6 } - operand_ids: 18 + operand_ids: 27 inner_subroutines { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1_mesh-1d" + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0_mesh-1d" subroutine_root_id: 33 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1_unidir-ring_barrier" + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0_unidir-ring_barrier" opcode: "barrier" - instruction_id: 28 + instruction_id: 29 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 6 } inner_subroutines { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1_unidir-ring_barrier_centralized" + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0_unidir-ring_barrier_centralized" subroutine_root_id: 31 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1_unidir-ring_barrier_centralized_coordinator_recv_from_6" + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0_unidir-ring_barrier_centralized_send_to_0" + opcode: "send" + instruction_id: 30 + communication_groups { + group_ids: 0 + } + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0_unidir-ring_barrier_centralized_recv_from_0" + opcode: "recv" + instruction_id: 31 + communication_groups { + group_ids: 0 + } + operand_ids: 30 + } + } + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 32 + bytes_in: 12 + bytes_out: 12 + communication_groups { + group_ids: 0 + group_ids: 0 + } + operand_ids: 29 + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-0_mesh-1d_ccw_reduction_0" + opcode: "call" + instruction_id: 33 + operand_ids: 32 + inner_subroutines { + name: "reduction_subroutine_ccw_phase_0" + subroutine_root_id: 36 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_ccw_phase_0" + opcode: "delay" + instruction_id: 34 + bytes_out: 12 + } + instructions { + name: "op2_ccw_phase_0" + opcode: "delay" + instruction_id: 35 + bytes_out: 12 + } + instructions { + name: "sum_ccw_phase_0" + opcode: "delay" + instruction_id: 36 + ops: 24 + operand_ids: 34 + operand_ids: 35 + } + } + } + } + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1" + opcode: "reduce-scatter" + instruction_id: 37 + bytes_out: 24 + communication_groups { + group_ids: 2 + group_ids: 6 + } + operand_ids: 27 + inner_subroutines { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1_mesh-1d" + subroutine_root_id: 43 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1_unidir-ring_barrier" + opcode: "barrier" + instruction_id: 38 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1_unidir-ring_barrier_centralized" + subroutine_root_id: 41 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1_unidir-ring_barrier_centralized_coordinator_recv_from_6" opcode: "recv" - instruction_id: 29 + instruction_id: 39 communication_groups { group_ids: 6 } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1_unidir-ring_barrier_centralized_coordinator_send_to_6" + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1_unidir-ring_barrier_centralized_coordinator_send_to_6" opcode: "send" - instruction_id: 30 + instruction_id: 40 communication_groups { group_ids: 6 } - operand_ids: 29 + operand_ids: 39 } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1_unidir-ring_barrier_centralized_root_2" + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1_unidir-ring_barrier_centralized_root_2" opcode: "null" - instruction_id: 31 - operand_ids: 30 + instruction_id: 41 + operand_ids: 40 } } } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1_mesh-1d_cw_sendrecv_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" - instruction_id: 32 - bytes_in: 6 - bytes_out: 6 + instruction_id: 42 + bytes_in: 12 + bytes_out: 12 communication_groups { group_ids: 6 group_ids: 6 } - operand_ids: 28 + operand_ids: 38 } instructions { - name: "all-reduce_mesh-2d_reduce-scatter_dim-1_mesh-1d_cw_reduction_0" + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_dim-1_mesh-1d_cw_reduction_0" opcode: "call" - instruction_id: 33 - operand_ids: 32 + instruction_id: 43 + operand_ids: 42 inner_subroutines { name: "reduction_subroutine_cw_phase_0" - subroutine_root_id: 36 + subroutine_root_id: 46 execution_probability: 1 execution_count: 1 instructions { name: "op1_cw_phase_0" opcode: "delay" - instruction_id: 34 - bytes_out: 6 + instruction_id: 44 + bytes_out: 12 } instructions { name: "op2_cw_phase_0" opcode: "delay" - instruction_id: 35 - bytes_out: 6 + instruction_id: 45 + bytes_out: 12 } instructions { name: "sum_cw_phase_0" opcode: "delay" - instruction_id: 36 - ops: 12 - operand_ids: 34 - operand_ids: 35 + instruction_id: 46 + ops: 24 + operand_ids: 44 + operand_ids: 45 + } + } + } + } + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_stage-1_root" + opcode: "null" + instruction_id: 47 + operand_ids: 28 + operand_ids: 37 + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_conc" + opcode: "reduce-scatter" + instruction_id: 48 + bytes_out: 48 + communication_groups { + group_ids: 2 + group_ids: 3 + } + operand_ids: 47 + inner_subroutines { + name: "all-reduce_mesh-2d_reduce-scatter_conc_mesh-1d" + subroutine_root_id: 54 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_conc_unidir-ring_barrier" + opcode: "barrier" + instruction_id: 49 + communication_groups { + group_ids: 2 + group_ids: 3 + } + inner_subroutines { + name: "all-reduce_mesh-2d_reduce-scatter_conc_unidir-ring_barrier_centralized" + subroutine_root_id: 52 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_conc_unidir-ring_barrier_centralized_coordinator_recv_from_3" + opcode: "recv" + instruction_id: 50 + communication_groups { + group_ids: 3 + } + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_conc_unidir-ring_barrier_centralized_coordinator_send_to_3" + opcode: "send" + instruction_id: 51 + communication_groups { + group_ids: 3 + } + operand_ids: 50 + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_conc_unidir-ring_barrier_centralized_root_2" + opcode: "null" + instruction_id: 52 + operand_ids: 51 + } + } + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_conc_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 53 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 3 + group_ids: 3 + } + operand_ids: 49 + } + instructions { + name: "all-reduce_mesh-2d_reduce-scatter_conc_mesh-1d_cw_reduction_0" + opcode: "call" + instruction_id: 54 + operand_ids: 53 + inner_subroutines { + name: "reduction_subroutine_cw_phase_0" + subroutine_root_id: 57 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_cw_phase_0" + opcode: "delay" + instruction_id: 55 + bytes_out: 24 + } + instructions { + name: "op2_cw_phase_0" + opcode: "delay" + instruction_id: 56 + bytes_out: 24 + } + instructions { + name: "sum_cw_phase_0" + opcode: "delay" + instruction_id: 57 + ops: 48 + operand_ids: 55 + operand_ids: 56 } } } @@ -807,7 +1165,7 @@ inner_subroutines { instructions { name: "all-reduce_mesh-2d_all-gather" opcode: "all-gather" - instruction_id: 37 + instruction_id: 58 bytes_out: 48 communication_groups { group_ids: 0 @@ -822,200 +1180,341 @@ inner_subroutines { operand_ids: 7 inner_subroutines { name: "all-reduce_mesh-2d_all-gather_mesh-2d" - subroutine_root_id: 49 + subroutine_root_id: 83 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_all-gather_dim-conc" + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-0" opcode: "all-gather" - instruction_id: 38 + instruction_id: 59 bytes_out: 12 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "all-reduce_mesh-2d_all-gather_dim-conc_mesh-1d" - subroutine_root_id: 43 + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-0_mesh-1d" + subroutine_root_id: 63 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_all-gather_dim-conc_mesh-1d_barrier" + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-0_mesh-1d_barrier" opcode: "barrier" - instruction_id: 39 + instruction_id: 60 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "all-reduce_mesh-2d_all-gather_dim-conc_mesh-1d_barrier_centralized" - subroutine_root_id: 42 + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-0_mesh-1d_barrier_centralized" + subroutine_root_id: 62 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_all-gather_dim-conc_mesh-1d_barrier_centralized_coordinator_recv_from_3" + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-0_mesh-1d_barrier_centralized_send_to_0" + opcode: "send" + instruction_id: 61 + communication_groups { + group_ids: 0 + } + } + instructions { + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-0_mesh-1d_barrier_centralized_recv_from_0" opcode: "recv" - instruction_id: 40 + instruction_id: 62 communication_groups { - group_ids: 3 + group_ids: 0 + } + operand_ids: 61 + } + } + } + instructions { + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-0_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 63 + bytes_in: 6 + bytes_out: 6 + communication_groups { + group_ids: 0 + group_ids: 0 + } + operand_ids: 60 + } + } + } + instructions { + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-1" + opcode: "all-gather" + instruction_id: 64 + bytes_out: 12 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-1_mesh-1d" + subroutine_root_id: 69 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-1_mesh-1d_barrier" + opcode: "barrier" + instruction_id: 65 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-1_mesh-1d_barrier_centralized" + subroutine_root_id: 68 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-1_mesh-1d_barrier_centralized_coordinator_recv_from_6" + opcode: "recv" + instruction_id: 66 + communication_groups { + group_ids: 6 } } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-conc_mesh-1d_barrier_centralized_coordinator_send_to_3" + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-1_mesh-1d_barrier_centralized_coordinator_send_to_6" opcode: "send" - instruction_id: 41 + instruction_id: 67 communication_groups { - group_ids: 3 + group_ids: 6 } - operand_ids: 40 + operand_ids: 66 } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-conc_mesh-1d_barrier_centralized_root_2" + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-1_mesh-1d_barrier_centralized_root_2" opcode: "null" - instruction_id: 42 - operand_ids: 41 + instruction_id: 68 + operand_ids: 67 } } } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-conc_mesh-1d_cw_sendrecv_0" + name: "all-reduce_mesh-2d_all-gather_stage-0_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" - instruction_id: 43 + instruction_id: 69 bytes_in: 6 bytes_out: 6 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 6 + group_ids: 6 } - operand_ids: 39 + operand_ids: 65 } } } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-0" + name: "all-reduce_mesh-2d_all-gather_stage-0_root" + opcode: "null" + instruction_id: 70 + operand_ids: 59 + operand_ids: 64 + } + instructions { + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-0" opcode: "all-gather" - instruction_id: 44 - bytes_out: 12 + instruction_id: 71 + bytes_out: 24 communication_groups { group_ids: 0 group_ids: 2 } - operand_ids: 38 + operand_ids: 70 inner_subroutines { - name: "all-reduce_mesh-2d_all-gather_dim-0_mesh-1d" - subroutine_root_id: 48 + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-0_mesh-1d" + subroutine_root_id: 75 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_all-gather_dim-0_mesh-1d_barrier" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-0_mesh-1d_barrier" opcode: "barrier" - instruction_id: 45 + instruction_id: 72 communication_groups { group_ids: 0 group_ids: 2 } inner_subroutines { - name: "all-reduce_mesh-2d_all-gather_dim-0_mesh-1d_barrier_centralized" - subroutine_root_id: 47 + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-0_mesh-1d_barrier_centralized" + subroutine_root_id: 74 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_all-gather_dim-0_mesh-1d_barrier_centralized_send_to_0" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-0_mesh-1d_barrier_centralized_send_to_0" opcode: "send" - instruction_id: 46 + instruction_id: 73 communication_groups { group_ids: 0 } } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-0_mesh-1d_barrier_centralized_recv_from_0" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-0_mesh-1d_barrier_centralized_recv_from_0" opcode: "recv" - instruction_id: 47 + instruction_id: 74 communication_groups { group_ids: 0 } - operand_ids: 46 + operand_ids: 73 } } } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-0_mesh-1d_ccw_sendrecv_0" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-0_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" - instruction_id: 48 - bytes_in: 6 - bytes_out: 6 + instruction_id: 75 + bytes_in: 12 + bytes_out: 12 communication_groups { group_ids: 0 group_ids: 0 } - operand_ids: 45 + operand_ids: 72 } } } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-1" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-1" opcode: "all-gather" - instruction_id: 49 - bytes_out: 12 + instruction_id: 76 + bytes_out: 24 communication_groups { group_ids: 2 group_ids: 6 } - operand_ids: 44 + operand_ids: 70 inner_subroutines { - name: "all-reduce_mesh-2d_all-gather_dim-1_mesh-1d" - subroutine_root_id: 54 + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-1_mesh-1d" + subroutine_root_id: 81 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_all-gather_dim-1_mesh-1d_barrier" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-1_mesh-1d_barrier" opcode: "barrier" - instruction_id: 50 + instruction_id: 77 communication_groups { group_ids: 2 group_ids: 6 } inner_subroutines { - name: "all-reduce_mesh-2d_all-gather_dim-1_mesh-1d_barrier_centralized" - subroutine_root_id: 53 + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-1_mesh-1d_barrier_centralized" + subroutine_root_id: 80 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_mesh-2d_all-gather_dim-1_mesh-1d_barrier_centralized_coordinator_recv_from_6" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-1_mesh-1d_barrier_centralized_coordinator_recv_from_6" opcode: "recv" - instruction_id: 51 + instruction_id: 78 communication_groups { group_ids: 6 } } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-1_mesh-1d_barrier_centralized_coordinator_send_to_6" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-1_mesh-1d_barrier_centralized_coordinator_send_to_6" opcode: "send" - instruction_id: 52 + instruction_id: 79 communication_groups { group_ids: 6 } - operand_ids: 51 + operand_ids: 78 } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-1_mesh-1d_barrier_centralized_root_2" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-1_mesh-1d_barrier_centralized_root_2" opcode: "null" - instruction_id: 53 - operand_ids: 52 + instruction_id: 80 + operand_ids: 79 } } } instructions { - name: "all-reduce_mesh-2d_all-gather_dim-1_mesh-1d_cw_sendrecv_0" + name: "all-reduce_mesh-2d_all-gather_stage-1_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" - instruction_id: 54 - bytes_in: 6 - bytes_out: 6 + instruction_id: 81 + bytes_in: 12 + bytes_out: 12 communication_groups { group_ids: 6 group_ids: 6 } - operand_ids: 50 + operand_ids: 77 + } + } + } + instructions { + name: "all-reduce_mesh-2d_all-gather_stage-1_root" + opcode: "null" + instruction_id: 82 + operand_ids: 71 + operand_ids: 76 + } + instructions { + name: "all-reduce_mesh-2d_all-gather_conc" + opcode: "all-gather" + instruction_id: 83 + bytes_out: 48 + communication_groups { + group_ids: 2 + group_ids: 3 + } + operand_ids: 82 + inner_subroutines { + name: "all-reduce_mesh-2d_all-gather_conc_mesh-1d" + subroutine_root_id: 88 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_all-gather_conc_mesh-1d_barrier" + opcode: "barrier" + instruction_id: 84 + communication_groups { + group_ids: 2 + group_ids: 3 + } + inner_subroutines { + name: "all-reduce_mesh-2d_all-gather_conc_mesh-1d_barrier_centralized" + subroutine_root_id: 87 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_mesh-2d_all-gather_conc_mesh-1d_barrier_centralized_coordinator_recv_from_3" + opcode: "recv" + instruction_id: 85 + communication_groups { + group_ids: 3 + } + } + instructions { + name: "all-reduce_mesh-2d_all-gather_conc_mesh-1d_barrier_centralized_coordinator_send_to_3" + opcode: "send" + instruction_id: 86 + communication_groups { + group_ids: 3 + } + operand_ids: 85 + } + instructions { + name: "all-reduce_mesh-2d_all-gather_conc_mesh-1d_barrier_centralized_root_2" + opcode: "null" + instruction_id: 87 + operand_ids: 86 + } + } + } + instructions { + name: "all-reduce_mesh-2d_all-gather_conc_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 88 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 3 + group_ids: 3 + } + operand_ids: 84 } } } @@ -1023,8 +1522,69 @@ inner_subroutines { } } )proto"; - google::protobuf::TextFormat::ParseFromString(allreduce_str, - &allreduce_proto); + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT + +// Tests expanding 2D-Mesh all-reduce with barrier +TEST(Mesh2dAllReduce, WithBarrier) { + auto graph = absl::make_unique("test_graph", 2); + auto sub = absl::make_unique( + "test_subroutine", graph.get()); + auto sub_ptr = sub.get(); + sub_ptr->SetId(3); + graph->SetEntrySubroutine(std::move(sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); + instr_1->SetOps(4); + + ASSERT_OK_AND_ASSIGN(auto allreduce, paragraph::Instruction::Create( + paragraph::Opcode::kAllReduce, "all-reduce", sub_ptr)); + allreduce->SetBytesOut(48); + paragraph::CommunicationGroup allreduce_group = {0, 1, 2, 3, 4, 5, 6, 7}; + allreduce->AppendCommunicationGroup(allreduce_group); + + auto reduction_sub = absl::make_unique( + "reduction_subroutine", graph.get()); + auto reduction_ptr = reduction_sub.get(); + ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "op1", reduction_ptr)); + op1->SetBytesOut(48); + ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "op2", reduction_ptr)); + op2->SetBytesOut(48); + ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); + sum_op->SetOps(96); + sum_op->AddOperand(op1); + sum_op->AddOperand(op2); + allreduce->AppendInnerSubroutine(std::move(reduction_sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); + instr_3->SetOps(4); + + nlohmann::json config = R"( + { + "all-reduce": { + "algorithm": "mesh-2d", + "concentration": 2, + "dimension_widths": [2, 2], + "barrier": { + "algorithm": "centralized" + } + } + } + )"_json; + + ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( + paragraph::TranslatorType::kCollective, config)); + EXPECT_OK(translators["all-reduce"]->Translate(allreduce)); + + paragraph::InstructionProto allreduce_proto = + Mesh2dAllReduce_with_barrier_test_proto(); EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( allreduce->ToProto().value(), allreduce_proto)); } diff --git a/paragraph/translation/allreduce/torus_2d_allreduce_translator_test.cc b/paragraph/translation/allreduce/torus_2d_allreduce_translator_test.cc index 97c20e1..9969525 100644 --- a/paragraph/translation/allreduce/torus_2d_allreduce_translator_test.cc +++ b/paragraph/translation/allreduce/torus_2d_allreduce_translator_test.cc @@ -24,61 +24,9 @@ #include "paragraph/shim/test_macros.h" #include "paragraph/translation/translation_map.h" -// Tests expanding 2D-Torus all-reduce -TEST(Torus2dAllReduce, NoBarrier) { - auto graph = absl::make_unique("test_graph", 2); - auto sub = absl::make_unique( - "test_subroutine", graph.get()); - auto sub_ptr = sub.get(); - sub_ptr->SetId(3); - graph->SetEntrySubroutine(std::move(sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); - instr_1->SetOps(4); - - ASSERT_OK_AND_ASSIGN(auto allreduce, paragraph::Instruction::Create( - paragraph::Opcode::kAllReduce, "all-reduce", sub_ptr)); - allreduce->SetBytesOut(48); - paragraph::CommunicationGroup allreduce_group = {0, 1, 2, 3, 4, 5, 6, 7}; - allreduce->AppendCommunicationGroup(allreduce_group); - - auto reduction_sub = absl::make_unique( - "reduction_subroutine", graph.get()); - auto reduction_ptr = reduction_sub.get(); - ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "op1", reduction_ptr)); - op1->SetBytesOut(48); - ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "op2", reduction_ptr)); - op2->SetBytesOut(48); - ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); - sum_op->SetOps(96); - sum_op->AddOperand(op1); - sum_op->AddOperand(op2); - allreduce->AppendInnerSubroutine(std::move(reduction_sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); - instr_3->SetOps(4); - - nlohmann::json config = R"( - { - "all-reduce": { - "algorithm": "torus-2d", - "concentration": 2, - "dimension_widths": [2, 2] - } - } - )"_json; - - ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( - paragraph::TranslatorType::kCollective, config)); - EXPECT_OK(translators["all-reduce"]->Translate(allreduce)); - - paragraph::InstructionProto allreduce_proto; - std::string allreduce_str = +paragraph::InstructionProto no_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "all-reduce" opcode: "all-reduce" @@ -96,7 +44,7 @@ communication_groups { } inner_subroutines { name: "all-reduce_torus-2d" - subroutine_root_id: 50 + subroutine_root_id: 80 execution_probability: 1 execution_count: 1 instructions { @@ -116,50 +64,50 @@ inner_subroutines { } inner_subroutines { name: "all-reduce_torus-2d_reduce-scatter_torus-2d" - subroutine_root_id: 36 + subroutine_root_id: 66 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-conc" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0" opcode: "reduce-scatter" instruction_id: 8 bytes_out: 12 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-conc_bidir-ring" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring" subroutine_root_id: 21 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-conc_bidir-ring_cw" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_cw" opcode: "reduce-scatter" instruction_id: 9 bytes_out: 6 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-conc_bidir-ring_cw_unidir-ring" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring" subroutine_root_id: 11 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-conc_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 10 bytes_in: 3 bytes_out: 3 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 0 } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-conc_bidir-ring_cw_unidir-ring_reduction_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_reduction_1" opcode: "call" instruction_id: 11 operand_ids: 10 @@ -193,32 +141,32 @@ inner_subroutines { } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-conc_bidir-ring_ccw" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_ccw" opcode: "reduce-scatter" instruction_id: 15 bytes_out: 6 communication_groups { - group_ids: 3 group_ids: 2 + group_ids: 0 } inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-conc_bidir-ring_ccw_unidir-ring" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring" subroutine_root_id: 17 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-conc_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 16 bytes_in: 3 bytes_out: 3 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 0 } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-conc_bidir-ring_ccw_unidir-ring_reduction_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" opcode: "call" instruction_id: 17 operand_ids: 16 @@ -252,7 +200,7 @@ inner_subroutines { } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-conc_bidir-ring_root_2" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_root_2" opcode: "null" instruction_id: 21 operand_ids: 9 @@ -261,47 +209,46 @@ inner_subroutines { } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1" opcode: "reduce-scatter" instruction_id: 22 bytes_out: 12 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } - operand_ids: 8 inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring" subroutine_root_id: 35 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_cw" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_cw" opcode: "reduce-scatter" instruction_id: 23 bytes_out: 6 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_cw_unidir-ring" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring" subroutine_root_id: 25 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 24 bytes_in: 3 bytes_out: 3 communication_groups { - group_ids: 0 - group_ids: 0 + group_ids: 6 + group_ids: 6 } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_cw_unidir-ring_reduction_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_reduction_1" opcode: "call" instruction_id: 25 operand_ids: 24 @@ -335,32 +282,32 @@ inner_subroutines { } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_ccw" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_ccw" opcode: "reduce-scatter" instruction_id: 29 bytes_out: 6 communication_groups { + group_ids: 6 group_ids: 2 - group_ids: 0 } inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring" subroutine_root_id: 31 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 30 bytes_in: 3 bytes_out: 3 communication_groups { - group_ids: 0 - group_ids: 0 + group_ids: 6 + group_ids: 6 } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" opcode: "call" instruction_id: 31 operand_ids: 30 @@ -394,7 +341,7 @@ inner_subroutines { } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_root_2" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_root_2" opcode: "null" instruction_id: 35 operand_ids: 23 @@ -403,307 +350,1286 @@ inner_subroutines { } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1" - opcode: "reduce-scatter" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_root" + opcode: "null" instruction_id: 36 - bytes_out: 12 + operand_ids: 8 + operand_ids: 22 + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0" + opcode: "reduce-scatter" + instruction_id: 37 + bytes_out: 24 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 6 } - operand_ids: 22 + operand_ids: 36 inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring" - subroutine_root_id: 49 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring" + subroutine_root_id: 50 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_cw" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_cw" opcode: "reduce-scatter" - instruction_id: 37 - bytes_out: 6 + instruction_id: 38 + bytes_out: 12 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 6 } inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_cw_unidir-ring" - subroutine_root_id: 39 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 40 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 38 - bytes_in: 3 - bytes_out: 3 + instruction_id: 39 + bytes_in: 6 + bytes_out: 6 communication_groups { - group_ids: 6 - group_ids: 6 + group_ids: 0 + group_ids: 0 } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_reduction_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_reduction_1" opcode: "call" - instruction_id: 39 - operand_ids: 38 + instruction_id: 40 + operand_ids: 39 inner_subroutines { name: "reduction_subroutine_phase_1" - subroutine_root_id: 42 + subroutine_root_id: 43 execution_probability: 1 execution_count: 1 instructions { name: "op1_phase_1" opcode: "delay" - instruction_id: 40 - bytes_out: 3 + instruction_id: 41 + bytes_out: 6 } instructions { name: "op2_phase_1" opcode: "delay" - instruction_id: 41 - bytes_out: 3 + instruction_id: 42 + bytes_out: 6 } instructions { name: "sum_phase_1" opcode: "delay" - instruction_id: 42 - ops: 6 - operand_ids: 40 + instruction_id: 43 + ops: 12 operand_ids: 41 + operand_ids: 42 } } } } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_ccw" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_ccw" opcode: "reduce-scatter" - instruction_id: 43 - bytes_out: 6 + instruction_id: 44 + bytes_out: 12 communication_groups { - group_ids: 6 group_ids: 2 + group_ids: 0 } inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 45 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 46 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 44 - bytes_in: 3 - bytes_out: 3 + instruction_id: 45 + bytes_in: 6 + bytes_out: 6 communication_groups { - group_ids: 6 - group_ids: 6 + group_ids: 0 + group_ids: 0 } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" opcode: "call" - instruction_id: 45 - operand_ids: 44 + instruction_id: 46 + operand_ids: 45 inner_subroutines { name: "reduction_subroutine_phase_1" - subroutine_root_id: 48 + subroutine_root_id: 49 execution_probability: 1 execution_count: 1 instructions { name: "op1_phase_1" opcode: "delay" - instruction_id: 46 - bytes_out: 3 + instruction_id: 47 + bytes_out: 6 } instructions { name: "op2_phase_1" opcode: "delay" - instruction_id: 47 - bytes_out: 3 + instruction_id: 48 + bytes_out: 6 } instructions { name: "sum_phase_1" opcode: "delay" - instruction_id: 48 - ops: 6 - operand_ids: 46 + instruction_id: 49 + ops: 12 operand_ids: 47 + operand_ids: 48 } } } } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_root_2" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_root_2" opcode: "null" - instruction_id: 49 - operand_ids: 37 - operand_ids: 43 + instruction_id: 50 + operand_ids: 38 + operand_ids: 44 } } } - } - } - instructions { - name: "all-reduce_torus-2d_all-gather" - opcode: "all-gather" - instruction_id: 50 - bytes_out: 48 - communication_groups { - group_ids: 0 - group_ids: 1 - group_ids: 2 - group_ids: 3 - group_ids: 4 - group_ids: 5 - group_ids: 6 - group_ids: 7 - } - operand_ids: 7 - inner_subroutines { - name: "all-reduce_torus-2d_all-gather_torus-2d" - subroutine_root_id: 63 - execution_probability: 1 - execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-conc" - opcode: "all-gather" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1" + opcode: "reduce-scatter" instruction_id: 51 - bytes_out: 12 + bytes_out: 24 communication_groups { group_ids: 2 - group_ids: 3 + group_ids: 6 } + operand_ids: 36 inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-conc_bidir-ring" - subroutine_root_id: 56 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring" + subroutine_root_id: 64 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-conc_bidir-ring_cw" - opcode: "all-gather" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_cw" + opcode: "reduce-scatter" instruction_id: 52 - bytes_out: 6 + bytes_out: 12 communication_groups { group_ids: 2 - group_ids: 3 + group_ids: 6 } inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-conc_bidir-ring_cw_unidir-ring" - subroutine_root_id: 53 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 54 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-conc_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 53 - bytes_in: 3 - bytes_out: 3 + bytes_in: 6 + bytes_out: 6 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 6 + group_ids: 6 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 54 + operand_ids: 53 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 57 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 55 + bytes_out: 6 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 56 + bytes_out: 6 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 57 + ops: 12 + operand_ids: 55 + operand_ids: 56 + } } } } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-conc_bidir-ring_ccw" - opcode: "all-gather" - instruction_id: 54 - bytes_out: 6 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 58 + bytes_out: 12 communication_groups { - group_ids: 3 + group_ids: 6 group_ids: 2 } inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-conc_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 55 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 60 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-conc_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 55 - bytes_in: 3 - bytes_out: 3 + instruction_id: 59 + bytes_in: 6 + bytes_out: 6 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 6 + group_ids: 6 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 60 + operand_ids: 59 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 63 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 61 + bytes_out: 6 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 62 + bytes_out: 6 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 63 + ops: 12 + operand_ids: 61 + operand_ids: 62 + } } } } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-conc_bidir-ring_root_2" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_root_2" opcode: "null" - instruction_id: 56 + instruction_id: 64 operand_ids: 52 - operand_ids: 54 + operand_ids: 58 } } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-0" - opcode: "all-gather" - instruction_id: 57 - bytes_out: 12 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_root" + opcode: "null" + instruction_id: 65 + operand_ids: 37 + operand_ids: 51 + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_conc" + opcode: "reduce-scatter" + instruction_id: 66 + bytes_out: 48 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 3 } - operand_ids: 51 + operand_ids: 65 inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring" - subroutine_root_id: 62 + name: "all-reduce_torus-2d_reduce-scatter_conc_bidir-ring" + subroutine_root_id: 79 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_cw" - opcode: "all-gather" - instruction_id: 58 - bytes_out: 6 + name: "all-reduce_torus-2d_reduce-scatter_conc_bidir-ring_cw" + opcode: "reduce-scatter" + instruction_id: 67 + bytes_out: 24 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 3 } inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_cw_unidir-ring" - subroutine_root_id: 59 + name: "all-reduce_torus-2d_reduce-scatter_conc_bidir-ring_cw_unidir-ring" + subroutine_root_id: 69 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_reduce-scatter_conc_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 59 - bytes_in: 3 - bytes_out: 3 + instruction_id: 68 + bytes_in: 12 + bytes_out: 12 communication_groups { - group_ids: 0 - group_ids: 0 + group_ids: 3 + group_ids: 3 } } - } - } - instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_ccw" - opcode: "all-gather" - instruction_id: 60 - bytes_out: 6 - communication_groups { - group_ids: 2 - group_ids: 0 - } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_conc_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 69 + operand_ids: 68 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 72 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 70 + bytes_out: 12 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 71 + bytes_out: 12 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 72 + ops: 24 + operand_ids: 70 + operand_ids: 71 + } + } + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_conc_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 73 + bytes_out: 24 + communication_groups { + group_ids: 3 + group_ids: 2 + } + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_conc_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 75 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_conc_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 74 + bytes_in: 12 + bytes_out: 12 + communication_groups { + group_ids: 3 + group_ids: 3 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_conc_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 75 + operand_ids: 74 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 78 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 76 + bytes_out: 12 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 77 + bytes_out: 12 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 78 + ops: 24 + operand_ids: 76 + operand_ids: 77 + } + } + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_conc_bidir-ring_root_2" + opcode: "null" + instruction_id: 79 + operand_ids: 67 + operand_ids: 73 + } + } + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather" + opcode: "all-gather" + instruction_id: 80 + bytes_out: 48 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 2 + group_ids: 3 + group_ids: 4 + group_ids: 5 + group_ids: 6 + group_ids: 7 + } + operand_ids: 7 + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_torus-2d" + subroutine_root_id: 107 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0" + opcode: "all-gather" + instruction_id: 81 + bytes_out: 12 + communication_groups { + group_ids: 0 + group_ids: 2 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring" + subroutine_root_id: 86 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_cw" + opcode: "all-gather" + instruction_id: 82 + bytes_out: 6 + communication_groups { + group_ids: 0 + group_ids: 2 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 83 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 83 + bytes_in: 3 + bytes_out: 3 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_ccw" + opcode: "all-gather" + instruction_id: 84 + bytes_out: 6 + communication_groups { + group_ids: 2 + group_ids: 0 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 85 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 85 + bytes_in: 3 + bytes_out: 3 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_root_2" + opcode: "null" + instruction_id: 86 + operand_ids: 82 + operand_ids: 84 + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1" + opcode: "all-gather" + instruction_id: 87 + bytes_out: 12 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring" + subroutine_root_id: 92 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_cw" + opcode: "all-gather" + instruction_id: 88 + bytes_out: 6 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 89 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 89 + bytes_in: 3 + bytes_out: 3 + communication_groups { + group_ids: 6 + group_ids: 6 + } + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_ccw" + opcode: "all-gather" + instruction_id: 90 + bytes_out: 6 + communication_groups { + group_ids: 6 + group_ids: 2 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 91 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 91 + bytes_in: 3 + bytes_out: 3 + communication_groups { + group_ids: 6 + group_ids: 6 + } + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_root_2" + opcode: "null" + instruction_id: 92 + operand_ids: 88 + operand_ids: 90 + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_root" + opcode: "null" + instruction_id: 93 + operand_ids: 81 + operand_ids: 87 + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0" + opcode: "all-gather" + instruction_id: 94 + bytes_out: 24 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 93 + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring" + subroutine_root_id: 99 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_cw" + opcode: "all-gather" + instruction_id: 95 + bytes_out: 12 + communication_groups { + group_ids: 0 + group_ids: 2 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 96 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 96 + bytes_in: 6 + bytes_out: 6 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_ccw" + opcode: "all-gather" + instruction_id: 97 + bytes_out: 12 + communication_groups { + group_ids: 2 + group_ids: 0 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 98 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 98 + bytes_in: 6 + bytes_out: 6 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_root_2" + opcode: "null" + instruction_id: 99 + operand_ids: 95 + operand_ids: 97 + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1" + opcode: "all-gather" + instruction_id: 100 + bytes_out: 24 + communication_groups { + group_ids: 2 + group_ids: 6 + } + operand_ids: 93 + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring" + subroutine_root_id: 105 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_cw" + opcode: "all-gather" + instruction_id: 101 + bytes_out: 12 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 102 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 102 + bytes_in: 6 + bytes_out: 6 + communication_groups { + group_ids: 6 + group_ids: 6 + } + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_ccw" + opcode: "all-gather" + instruction_id: 103 + bytes_out: 12 + communication_groups { + group_ids: 6 + group_ids: 2 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 104 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 104 + bytes_in: 6 + bytes_out: 6 + communication_groups { + group_ids: 6 + group_ids: 6 + } + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_root_2" + opcode: "null" + instruction_id: 105 + operand_ids: 101 + operand_ids: 103 + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_root" + opcode: "null" + instruction_id: 106 + operand_ids: 94 + operand_ids: 100 + } + instructions { + name: "all-reduce_torus-2d_all-gather_conc" + opcode: "all-gather" + instruction_id: 107 + bytes_out: 48 + communication_groups { + group_ids: 2 + group_ids: 3 + } + operand_ids: 106 + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_conc_bidir-ring" + subroutine_root_id: 112 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_conc_bidir-ring_cw" + opcode: "all-gather" + instruction_id: 108 + bytes_out: 24 + communication_groups { + group_ids: 2 + group_ids: 3 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_conc_bidir-ring_cw_unidir-ring" + subroutine_root_id: 109 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_conc_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 109 + bytes_in: 12 + bytes_out: 12 + communication_groups { + group_ids: 3 + group_ids: 3 + } + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_conc_bidir-ring_ccw" + opcode: "all-gather" + instruction_id: 110 + bytes_out: 24 + communication_groups { + group_ids: 3 + group_ids: 2 + } + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_conc_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 111 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_conc_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 111 + bytes_in: 12 + bytes_out: 12 + communication_groups { + group_ids: 3 + group_ids: 3 + } + } + } + } + instructions { + name: "all-reduce_torus-2d_all-gather_conc_bidir-ring_root_2" + opcode: "null" + instruction_id: 112 + operand_ids: 108 + operand_ids: 110 + } + } + } + } + } +} + )proto"; + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT + +// Tests expanding 2D-Torus all-reduce +TEST(Torus2dAllReduce, NoBarrier) { + auto graph = absl::make_unique("test_graph", 2); + auto sub = absl::make_unique( + "test_subroutine", graph.get()); + auto sub_ptr = sub.get(); + sub_ptr->SetId(3); + graph->SetEntrySubroutine(std::move(sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); + instr_1->SetOps(4); + + ASSERT_OK_AND_ASSIGN(auto allreduce, paragraph::Instruction::Create( + paragraph::Opcode::kAllReduce, "all-reduce", sub_ptr)); + allreduce->SetBytesOut(48); + paragraph::CommunicationGroup allreduce_group = {0, 1, 2, 3, 4, 5, 6, 7}; + allreduce->AppendCommunicationGroup(allreduce_group); + + auto reduction_sub = absl::make_unique( + "reduction_subroutine", graph.get()); + auto reduction_ptr = reduction_sub.get(); + ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "op1", reduction_ptr)); + op1->SetBytesOut(48); + ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "op2", reduction_ptr)); + op2->SetBytesOut(48); + ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); + sum_op->SetOps(96); + sum_op->AddOperand(op1); + sum_op->AddOperand(op2); + allreduce->AppendInnerSubroutine(std::move(reduction_sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); + instr_3->SetOps(4); + + nlohmann::json config = R"( + { + "all-reduce": { + "algorithm": "torus-2d", + "concentration": 2, + "dimension_widths": [2, 2] + } + } + )"_json; + + ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( + paragraph::TranslatorType::kCollective, config)); + EXPECT_OK(translators["all-reduce"]->Translate(allreduce)); + + paragraph::InstructionProto allreduce_proto = no_barrier_test_proto(); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + allreduce->ToProto().value(), allreduce_proto)); +} + +paragraph::InstructionProto with_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = + R"proto( +name: "all-reduce" +opcode: "all-reduce" +instruction_id: 2 +bytes_out: 48 +communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 2 + group_ids: 3 + group_ids: 4 + group_ids: 5 + group_ids: 6 + group_ids: 7 +} +inner_subroutines { + name: "all-reduce_torus-2d" + subroutine_root_id: 80 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter" + opcode: "reduce-scatter" + instruction_id: 7 + bytes_out: 48 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 2 + group_ids: 3 + group_ids: 4 + group_ids: 5 + group_ids: 6 + group_ids: 7 + } + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_torus-2d" + subroutine_root_id: 79 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0" + opcode: "reduce-scatter" + instruction_id: 8 + bytes_out: 12 + communication_groups { + group_ids: 2 + group_ids: 3 + } + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring" + subroutine_root_id: 25 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_barrier" + opcode: "barrier" + instruction_id: 9 + communication_groups { + group_ids: 2 + group_ids: 3 + } + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_barrier_centralized" + subroutine_root_id: 12 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_barrier_centralized_coordinator_recv_from_3" + opcode: "recv" + instruction_id: 10 + communication_groups { + group_ids: 3 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_barrier_centralized_coordinator_send_to_3" + opcode: "send" + instruction_id: 11 + communication_groups { + group_ids: 3 + } + operand_ids: 10 + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_barrier_centralized_root_2" + opcode: "null" + instruction_id: 12 + operand_ids: 11 + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_cw" + opcode: "reduce-scatter" + instruction_id: 13 + bytes_out: 6 + communication_groups { + group_ids: 2 + group_ids: 3 + } + operand_ids: 9 + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 15 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 14 + bytes_in: 3 + bytes_out: 3 + communication_groups { + group_ids: 3 + group_ids: 3 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 15 + operand_ids: 14 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 18 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 16 + bytes_out: 3 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 17 + bytes_out: 3 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 18 + ops: 6 + operand_ids: 16 + operand_ids: 17 + } + } + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 19 + bytes_out: 6 + communication_groups { + group_ids: 3 + group_ids: 2 + } + operand_ids: 9 + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 21 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 20 + bytes_in: 3 + bytes_out: 3 + communication_groups { + group_ids: 3 + group_ids: 3 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 21 + operand_ids: 20 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 24 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 22 + bytes_out: 3 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 23 + bytes_out: 3 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 24 + ops: 6 + operand_ids: 22 + operand_ids: 23 + } + } + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-0_bidir-ring_root_2" + opcode: "null" + instruction_id: 25 + operand_ids: 13 + operand_ids: 19 + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1" + opcode: "reduce-scatter" + instruction_id: 26 + bytes_out: 12 + communication_groups { + group_ids: 0 + group_ids: 2 + } + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring" + subroutine_root_id: 42 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_barrier" + opcode: "barrier" + instruction_id: 27 + communication_groups { + group_ids: 0 + group_ids: 2 + } + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_barrier_centralized" + subroutine_root_id: 29 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_barrier_centralized_send_to_0" + opcode: "send" + instruction_id: 28 + communication_groups { + group_ids: 0 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_barrier_centralized_recv_from_0" + opcode: "recv" + instruction_id: 29 + communication_groups { + group_ids: 0 + } + operand_ids: 28 + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_cw" + opcode: "reduce-scatter" + instruction_id: 30 + bytes_out: 6 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 27 + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 32 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 31 + bytes_in: 3 + bytes_out: 3 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 32 + operand_ids: 31 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 35 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 33 + bytes_out: 3 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 34 + bytes_out: 3 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 35 + ops: 6 + operand_ids: 33 + operand_ids: 34 + } + } + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 36 + bytes_out: 6 + communication_groups { + group_ids: 2 + group_ids: 0 + } + operand_ids: 27 inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 61 + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 38 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 61 + instruction_id: 37 bytes_in: 3 bytes_out: 3 communication_groups { @@ -711,186 +1637,427 @@ inner_subroutines { group_ids: 0 } } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 38 + operand_ids: 37 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 41 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 39 + bytes_out: 3 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 40 + bytes_out: 3 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 41 + ops: 6 + operand_ids: 39 + operand_ids: 40 + } + } + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_dim-1_bidir-ring_root_2" + opcode: "null" + instruction_id: 42 + operand_ids: 30 + operand_ids: 36 + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-0_root" + opcode: "null" + instruction_id: 43 + operand_ids: 8 + operand_ids: 26 + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0" + opcode: "reduce-scatter" + instruction_id: 44 + bytes_out: 24 + communication_groups { + group_ids: 2 + group_ids: 3 + } + operand_ids: 43 + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring" + subroutine_root_id: 61 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_barrier" + opcode: "barrier" + instruction_id: 45 + communication_groups { + group_ids: 2 + group_ids: 3 + } + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_barrier_centralized" + subroutine_root_id: 48 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_barrier_centralized_coordinator_recv_from_3" + opcode: "recv" + instruction_id: 46 + communication_groups { + group_ids: 3 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_barrier_centralized_coordinator_send_to_3" + opcode: "send" + instruction_id: 47 + communication_groups { + group_ids: 3 + } + operand_ids: 46 + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_barrier_centralized_root_2" + opcode: "null" + instruction_id: 48 + operand_ids: 47 + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_cw" + opcode: "reduce-scatter" + instruction_id: 49 + bytes_out: 12 + communication_groups { + group_ids: 2 + group_ids: 3 + } + operand_ids: 45 + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 51 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 50 + bytes_in: 6 + bytes_out: 6 + communication_groups { + group_ids: 3 + group_ids: 3 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 51 + operand_ids: 50 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 54 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 52 + bytes_out: 6 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 53 + bytes_out: 6 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 54 + ops: 12 + operand_ids: 52 + operand_ids: 53 + } + } + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 55 + bytes_out: 12 + communication_groups { + group_ids: 3 + group_ids: 2 + } + operand_ids: 45 + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 57 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 56 + bytes_in: 6 + bytes_out: 6 + communication_groups { + group_ids: 3 + group_ids: 3 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 57 + operand_ids: 56 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 60 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 58 + bytes_out: 6 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 59 + bytes_out: 6 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 60 + ops: 12 + operand_ids: 58 + operand_ids: 59 + } + } + } } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_root_2" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-0_bidir-ring_root_2" opcode: "null" - instruction_id: 62 - operand_ids: 58 - operand_ids: 60 + instruction_id: 61 + operand_ids: 49 + operand_ids: 55 } } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-1" - opcode: "all-gather" - instruction_id: 63 - bytes_out: 12 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1" + opcode: "reduce-scatter" + instruction_id: 62 + bytes_out: 24 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 6 } - operand_ids: 57 + operand_ids: 43 inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring" - subroutine_root_id: 68 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring" + subroutine_root_id: 78 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_cw" - opcode: "all-gather" - instruction_id: 64 - bytes_out: 6 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_barrier" + opcode: "barrier" + instruction_id: 63 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 6 } inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_cw_unidir-ring" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_barrier_centralized" subroutine_root_id: 65 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" - opcode: "sendrecv" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_barrier_centralized_send_to_0" + opcode: "send" + instruction_id: 64 + communication_groups { + group_ids: 0 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_barrier_centralized_recv_from_0" + opcode: "recv" instruction_id: 65 - bytes_in: 3 - bytes_out: 3 communication_groups { - group_ids: 6 - group_ids: 6 + group_ids: 0 } + operand_ids: 64 } } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_ccw" - opcode: "all-gather" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_cw" + opcode: "reduce-scatter" instruction_id: 66 - bytes_out: 6 + bytes_out: 12 communication_groups { - group_ids: 6 + group_ids: 0 group_ids: 2 } + operand_ids: 63 inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 67 + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 68 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 67 - bytes_in: 3 - bytes_out: 3 + bytes_in: 6 + bytes_out: 6 communication_groups { - group_ids: 6 - group_ids: 6 + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 68 + operand_ids: 67 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 71 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 69 + bytes_out: 6 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 70 + bytes_out: 6 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 71 + ops: 12 + operand_ids: 69 + operand_ids: 70 + } + } + } + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 72 + bytes_out: 12 + communication_groups { + group_ids: 2 + group_ids: 0 + } + operand_ids: 63 + inner_subroutines { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 74 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 73 + bytes_in: 6 + bytes_out: 6 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 74 + operand_ids: 73 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 77 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 75 + bytes_out: 6 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 76 + bytes_out: 6 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 77 + ops: 12 + operand_ids: 75 + operand_ids: 76 + } } } } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_root_2" + name: "all-reduce_torus-2d_reduce-scatter_stage-1_dim-1_bidir-ring_root_2" opcode: "null" - instruction_id: 68 - operand_ids: 64 + instruction_id: 78 operand_ids: 66 + operand_ids: 72 } } } - } - } -} - )proto"; - google::protobuf::TextFormat::ParseFromString(allreduce_str, - &allreduce_proto); - EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( - allreduce->ToProto().value(), allreduce_proto)); -} - -// Tests expanding 2D-Torus all-reduce with barrier -TEST(Torus2dAllReduce, WithBarrier) { - auto graph = absl::make_unique("test_graph", 2); - auto sub = absl::make_unique( - "test_subroutine", graph.get()); - auto sub_ptr = sub.get(); - sub_ptr->SetId(3); - graph->SetEntrySubroutine(std::move(sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); - instr_1->SetOps(4); - - ASSERT_OK_AND_ASSIGN(auto allreduce, paragraph::Instruction::Create( - paragraph::Opcode::kAllReduce, "all-reduce", sub_ptr)); - allreduce->SetBytesOut(48); - paragraph::CommunicationGroup allreduce_group = {0, 1, 2, 3, 4, 5, 6, 7}; - allreduce->AppendCommunicationGroup(allreduce_group); - - auto reduction_sub = absl::make_unique( - "reduction_subroutine", graph.get()); - auto reduction_ptr = reduction_sub.get(); - ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "op1", reduction_ptr)); - op1->SetBytesOut(48); - ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "op2", reduction_ptr)); - op2->SetBytesOut(48); - ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); - sum_op->SetOps(96); - sum_op->AddOperand(op1); - sum_op->AddOperand(op2); - allreduce->AppendInnerSubroutine(std::move(reduction_sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); - instr_3->SetOps(4); - - nlohmann::json config = R"( - { - "all-reduce": { - "algorithm": "torus-2d", - "dimension_widths": [2, 2], - "barrier": { - "algorithm": "centralized" - } + instructions { + name: "all-reduce_torus-2d_reduce-scatter_stage-1_root" + opcode: "null" + instruction_id: 79 + operand_ids: 44 + operand_ids: 62 } } - )"_json; - - ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( - paragraph::TranslatorType::kCollective, config)); - EXPECT_OK(translators["all-reduce"]->Translate(allreduce)); - - paragraph::InstructionProto allreduce_proto; - std::string allreduce_str = - R"proto( -name: "all-reduce" -opcode: "all-reduce" -instruction_id: 2 -bytes_out: 48 -communication_groups { - group_ids: 0 - group_ids: 1 - group_ids: 2 - group_ids: 3 - group_ids: 4 - group_ids: 5 - group_ids: 6 - group_ids: 7 -} -inner_subroutines { - name: "all-reduce_torus-2d" - subroutine_root_id: 43 - execution_probability: 1 - execution_count: 1 + } instructions { - name: "all-reduce_torus-2d_reduce-scatter" - opcode: "reduce-scatter" - instruction_id: 7 + name: "all-reduce_torus-2d_all-gather" + opcode: "all-gather" + instruction_id: 80 bytes_out: 48 communication_groups { group_ids: 0 @@ -902,142 +2069,111 @@ inner_subroutines { group_ids: 6 group_ids: 7 } + operand_ids: 7 inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_torus-2d" - subroutine_root_id: 26 + name: "all-reduce_torus-2d_all-gather_torus-2d" + subroutine_root_id: 120 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0" - opcode: "reduce-scatter" - instruction_id: 8 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0" + opcode: "all-gather" + instruction_id: 81 bytes_out: 12 communication_groups { group_ids: 2 group_ids: 3 } inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring" - subroutine_root_id: 25 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring" + subroutine_root_id: 90 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_barrier" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_barrier" opcode: "barrier" - instruction_id: 9 + instruction_id: 82 communication_groups { group_ids: 2 group_ids: 3 } inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_barrier_centralized" - subroutine_root_id: 12 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_barrier_centralized" + subroutine_root_id: 85 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_barrier_centralized_coordinator_recv_from_3" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_barrier_centralized_coordinator_recv_from_3" opcode: "recv" - instruction_id: 10 + instruction_id: 83 communication_groups { group_ids: 3 } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_barrier_centralized_coordinator_send_to_3" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_barrier_centralized_coordinator_send_to_3" opcode: "send" - instruction_id: 11 + instruction_id: 84 communication_groups { group_ids: 3 } - operand_ids: 10 + operand_ids: 83 } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_barrier_centralized_root_2" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_barrier_centralized_root_2" opcode: "null" - instruction_id: 12 - operand_ids: 11 + instruction_id: 85 + operand_ids: 84 } } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_cw" - opcode: "reduce-scatter" - instruction_id: 13 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_cw" + opcode: "all-gather" + instruction_id: 86 bytes_out: 6 communication_groups { group_ids: 2 group_ids: 3 } - operand_ids: 9 - inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_cw_unidir-ring" - subroutine_root_id: 15 - execution_probability: 1 - execution_count: 1 - instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" - opcode: "sendrecv" - instruction_id: 14 - bytes_in: 3 - bytes_out: 3 - communication_groups { - group_ids: 3 - group_ids: 3 - } - } - instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_cw_unidir-ring_reduction_1" - opcode: "call" - instruction_id: 15 - operand_ids: 14 - inner_subroutines { - name: "reduction_subroutine_phase_1" - subroutine_root_id: 18 - execution_probability: 1 - execution_count: 1 - instructions { - name: "op1_phase_1" - opcode: "delay" - instruction_id: 16 - bytes_out: 3 - } - instructions { - name: "op2_phase_1" - opcode: "delay" - instruction_id: 17 - bytes_out: 3 - } - instructions { - name: "sum_phase_1" - opcode: "delay" - instruction_id: 18 - ops: 6 - operand_ids: 16 - operand_ids: 17 - } + operand_ids: 82 + inner_subroutines { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 87 + execution_probability: 1 + execution_count: 1 + instructions { + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 87 + bytes_in: 3 + bytes_out: 3 + communication_groups { + group_ids: 3 + group_ids: 3 } } } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_ccw" - opcode: "reduce-scatter" - instruction_id: 19 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_ccw" + opcode: "all-gather" + instruction_id: 88 bytes_out: 6 communication_groups { group_ids: 3 group_ids: 2 } - operand_ids: 9 + operand_ids: 82 inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 21 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 89 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 20 + instruction_id: 89 bytes_in: 3 bytes_out: 3 communication_groups { @@ -1045,115 +2181,82 @@ inner_subroutines { group_ids: 3 } } - instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" - opcode: "call" - instruction_id: 21 - operand_ids: 20 - inner_subroutines { - name: "reduction_subroutine_phase_1" - subroutine_root_id: 24 - execution_probability: 1 - execution_count: 1 - instructions { - name: "op1_phase_1" - opcode: "delay" - instruction_id: 22 - bytes_out: 3 - } - instructions { - name: "op2_phase_1" - opcode: "delay" - instruction_id: 23 - bytes_out: 3 - } - instructions { - name: "sum_phase_1" - opcode: "delay" - instruction_id: 24 - ops: 6 - operand_ids: 22 - operand_ids: 23 - } - } - } } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-0_bidir-ring_root_2" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-0_bidir-ring_root_2" opcode: "null" - instruction_id: 25 - operand_ids: 13 - operand_ids: 19 + instruction_id: 90 + operand_ids: 86 + operand_ids: 88 } } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1" - opcode: "reduce-scatter" - instruction_id: 26 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1" + opcode: "all-gather" + instruction_id: 91 bytes_out: 12 communication_groups { group_ids: 0 group_ids: 2 } - operand_ids: 8 inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring" - subroutine_root_id: 42 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring" + subroutine_root_id: 99 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_barrier" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_barrier" opcode: "barrier" - instruction_id: 27 + instruction_id: 92 communication_groups { group_ids: 0 group_ids: 2 } inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_barrier_centralized" - subroutine_root_id: 29 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_barrier_centralized" + subroutine_root_id: 94 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_barrier_centralized_send_to_0" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_barrier_centralized_send_to_0" opcode: "send" - instruction_id: 28 + instruction_id: 93 communication_groups { group_ids: 0 } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_barrier_centralized_recv_from_0" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_barrier_centralized_recv_from_0" opcode: "recv" - instruction_id: 29 + instruction_id: 94 communication_groups { group_ids: 0 } - operand_ids: 28 + operand_ids: 93 } } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_cw" - opcode: "reduce-scatter" - instruction_id: 30 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_cw" + opcode: "all-gather" + instruction_id: 95 bytes_out: 6 communication_groups { group_ids: 0 group_ids: 2 } - operand_ids: 27 + operand_ids: 92 inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_cw_unidir-ring" - subroutine_root_id: 32 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 96 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 31 + instruction_id: 96 bytes_in: 3 bytes_out: 3 communication_groups { @@ -1161,59 +2264,27 @@ inner_subroutines { group_ids: 0 } } - instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_reduction_1" - opcode: "call" - instruction_id: 32 - operand_ids: 31 - inner_subroutines { - name: "reduction_subroutine_phase_1" - subroutine_root_id: 35 - execution_probability: 1 - execution_count: 1 - instructions { - name: "op1_phase_1" - opcode: "delay" - instruction_id: 33 - bytes_out: 3 - } - instructions { - name: "op2_phase_1" - opcode: "delay" - instruction_id: 34 - bytes_out: 3 - } - instructions { - name: "sum_phase_1" - opcode: "delay" - instruction_id: 35 - ops: 6 - operand_ids: 33 - operand_ids: 34 - } - } - } } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_ccw" - opcode: "reduce-scatter" - instruction_id: 36 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_ccw" + opcode: "all-gather" + instruction_id: 97 bytes_out: 6 communication_groups { group_ids: 2 group_ids: 0 } - operand_ids: 27 + operand_ids: 92 inner_subroutines { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 38 + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 98 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 37 + instruction_id: 98 bytes_in: 3 bytes_out: 3 communication_groups { @@ -1221,145 +2292,98 @@ inner_subroutines { group_ids: 0 } } - instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" - opcode: "call" - instruction_id: 38 - operand_ids: 37 - inner_subroutines { - name: "reduction_subroutine_phase_1" - subroutine_root_id: 41 - execution_probability: 1 - execution_count: 1 - instructions { - name: "op1_phase_1" - opcode: "delay" - instruction_id: 39 - bytes_out: 3 - } - instructions { - name: "op2_phase_1" - opcode: "delay" - instruction_id: 40 - bytes_out: 3 - } - instructions { - name: "sum_phase_1" - opcode: "delay" - instruction_id: 41 - ops: 6 - operand_ids: 39 - operand_ids: 40 - } - } - } } } instructions { - name: "all-reduce_torus-2d_reduce-scatter_dim-1_bidir-ring_root_2" + name: "all-reduce_torus-2d_all-gather_stage-0_dim-1_bidir-ring_root_2" opcode: "null" - instruction_id: 42 - operand_ids: 30 - operand_ids: 36 + instruction_id: 99 + operand_ids: 95 + operand_ids: 97 } } } - } - } - instructions { - name: "all-reduce_torus-2d_all-gather" - opcode: "all-gather" - instruction_id: 43 - bytes_out: 48 - communication_groups { - group_ids: 0 - group_ids: 1 - group_ids: 2 - group_ids: 3 - group_ids: 4 - group_ids: 5 - group_ids: 6 - group_ids: 7 - } - operand_ids: 7 - inner_subroutines { - name: "all-reduce_torus-2d_all-gather_torus-2d" - subroutine_root_id: 54 - execution_probability: 1 - execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-0" + name: "all-reduce_torus-2d_all-gather_stage-0_root" + opcode: "null" + instruction_id: 100 + operand_ids: 81 + operand_ids: 91 + } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0" opcode: "all-gather" - instruction_id: 44 - bytes_out: 12 + instruction_id: 101 + bytes_out: 24 communication_groups { group_ids: 2 group_ids: 3 } + operand_ids: 100 inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring" - subroutine_root_id: 53 + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring" + subroutine_root_id: 110 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_barrier" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_barrier" opcode: "barrier" - instruction_id: 45 + instruction_id: 102 communication_groups { group_ids: 2 group_ids: 3 } inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_barrier_centralized" - subroutine_root_id: 48 + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_barrier_centralized" + subroutine_root_id: 105 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_barrier_centralized_coordinator_recv_from_3" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_barrier_centralized_coordinator_recv_from_3" opcode: "recv" - instruction_id: 46 + instruction_id: 103 communication_groups { group_ids: 3 } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_barrier_centralized_coordinator_send_to_3" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_barrier_centralized_coordinator_send_to_3" opcode: "send" - instruction_id: 47 + instruction_id: 104 communication_groups { group_ids: 3 } - operand_ids: 46 + operand_ids: 103 } instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_barrier_centralized_root_2" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_barrier_centralized_root_2" opcode: "null" - instruction_id: 48 - operand_ids: 47 + instruction_id: 105 + operand_ids: 104 } } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_cw" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_cw" opcode: "all-gather" - instruction_id: 49 - bytes_out: 6 + instruction_id: 106 + bytes_out: 12 communication_groups { group_ids: 2 group_ids: 3 } - operand_ids: 45 + operand_ids: 102 inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_cw_unidir-ring" - subroutine_root_id: 50 + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 107 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 50 - bytes_in: 3 - bytes_out: 3 + instruction_id: 107 + bytes_in: 6 + bytes_out: 6 communication_groups { group_ids: 3 group_ids: 3 @@ -1368,26 +2392,26 @@ inner_subroutines { } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_ccw" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_ccw" opcode: "all-gather" - instruction_id: 51 - bytes_out: 6 + instruction_id: 108 + bytes_out: 12 communication_groups { group_ids: 3 group_ids: 2 } - operand_ids: 45 + operand_ids: 102 inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 52 + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 109 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 52 - bytes_in: 3 - bytes_out: 3 + instruction_id: 109 + bytes_in: 6 + bytes_out: 6 communication_groups { group_ids: 3 group_ids: 3 @@ -1396,82 +2420,82 @@ inner_subroutines { } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-0_bidir-ring_root_2" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-0_bidir-ring_root_2" opcode: "null" - instruction_id: 53 - operand_ids: 49 - operand_ids: 51 + instruction_id: 110 + operand_ids: 106 + operand_ids: 108 } } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-1" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1" opcode: "all-gather" - instruction_id: 54 - bytes_out: 12 + instruction_id: 111 + bytes_out: 24 communication_groups { group_ids: 0 group_ids: 2 } - operand_ids: 44 + operand_ids: 100 inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring" - subroutine_root_id: 62 + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring" + subroutine_root_id: 119 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_barrier" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_barrier" opcode: "barrier" - instruction_id: 55 + instruction_id: 112 communication_groups { group_ids: 0 group_ids: 2 } inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_barrier_centralized" - subroutine_root_id: 57 + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_barrier_centralized" + subroutine_root_id: 114 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_barrier_centralized_send_to_0" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_barrier_centralized_send_to_0" opcode: "send" - instruction_id: 56 + instruction_id: 113 communication_groups { group_ids: 0 } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_barrier_centralized_recv_from_0" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_barrier_centralized_recv_from_0" opcode: "recv" - instruction_id: 57 + instruction_id: 114 communication_groups { group_ids: 0 } - operand_ids: 56 + operand_ids: 113 } } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_cw" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_cw" opcode: "all-gather" - instruction_id: 58 - bytes_out: 6 + instruction_id: 115 + bytes_out: 12 communication_groups { group_ids: 0 group_ids: 2 } - operand_ids: 55 + operand_ids: 112 inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_cw_unidir-ring" - subroutine_root_id: 59 + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 116 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 59 - bytes_in: 3 - bytes_out: 3 + instruction_id: 116 + bytes_in: 6 + bytes_out: 6 communication_groups { group_ids: 0 group_ids: 0 @@ -1480,26 +2504,26 @@ inner_subroutines { } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_ccw" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_ccw" opcode: "all-gather" - instruction_id: 60 - bytes_out: 6 + instruction_id: 117 + bytes_out: 12 communication_groups { group_ids: 2 group_ids: 0 } - operand_ids: 55 + operand_ids: 112 inner_subroutines { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 61 + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 118 execution_probability: 1 execution_count: 1 instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 61 - bytes_in: 3 - bytes_out: 3 + instruction_id: 118 + bytes_in: 6 + bytes_out: 6 communication_groups { group_ids: 0 group_ids: 0 @@ -1508,20 +2532,86 @@ inner_subroutines { } } instructions { - name: "all-reduce_torus-2d_all-gather_dim-1_bidir-ring_root_2" + name: "all-reduce_torus-2d_all-gather_stage-1_dim-1_bidir-ring_root_2" opcode: "null" - instruction_id: 62 - operand_ids: 58 - operand_ids: 60 + instruction_id: 119 + operand_ids: 115 + operand_ids: 117 } } } + instructions { + name: "all-reduce_torus-2d_all-gather_stage-1_root" + opcode: "null" + instruction_id: 120 + operand_ids: 101 + operand_ids: 111 + } } } } - )proto"; - google::protobuf::TextFormat::ParseFromString(allreduce_str, - &allreduce_proto); + )proto"; + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT + +// Tests expanding 2D-Torus all-reduce with barrier +TEST(Torus2dAllReduce, WithBarrier) { + auto graph = absl::make_unique("test_graph", 2); + auto sub = absl::make_unique( + "test_subroutine", graph.get()); + auto sub_ptr = sub.get(); + sub_ptr->SetId(3); + graph->SetEntrySubroutine(std::move(sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); + instr_1->SetOps(4); + + ASSERT_OK_AND_ASSIGN(auto allreduce, paragraph::Instruction::Create( + paragraph::Opcode::kAllReduce, "all-reduce", sub_ptr)); + allreduce->SetBytesOut(48); + paragraph::CommunicationGroup allreduce_group = {0, 1, 2, 3, 4, 5, 6, 7}; + allreduce->AppendCommunicationGroup(allreduce_group); + + auto reduction_sub = absl::make_unique( + "reduction_subroutine", graph.get()); + auto reduction_ptr = reduction_sub.get(); + ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "op1", reduction_ptr)); + op1->SetBytesOut(48); + ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "op2", reduction_ptr)); + op2->SetBytesOut(48); + ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); + sum_op->SetOps(96); + sum_op->AddOperand(op1); + sum_op->AddOperand(op2); + allreduce->AppendInnerSubroutine(std::move(reduction_sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); + instr_3->SetOps(4); + + nlohmann::json config = R"( + { + "all-reduce": { + "algorithm": "torus-2d", + "dimension_widths": [2, 2], + "barrier": { + "algorithm": "centralized" + } + } + } + )"_json; + + ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( + paragraph::TranslatorType::kCollective, config)); + EXPECT_OK(translators["all-reduce"]->Translate(allreduce)); + + paragraph::InstructionProto allreduce_proto = with_barrier_test_proto(); EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( allreduce->ToProto().value(), allreduce_proto)); } diff --git a/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator.cc b/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator.cc index 0c45fe8..0218b6e 100644 --- a/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator.cc +++ b/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator.cc @@ -44,6 +44,12 @@ Mesh2dReduceScatterTranslator::Mesh2dReduceScatterTranslator( if (config.find("concentration") != config.end()) { concentration_ = config["concentration"].get(); } + // conentrated ports + integrated_local_exchange_ = false; + if (config.find("integrated_local_exchange") != config.end()) { + integrated_local_exchange_ = + config["integrated_local_exchange"].get(); + } // Create json config for internal 1D Mesh reduce-scatter nlohmann::json implicit_config = R"( @@ -77,77 +83,92 @@ shim::StatusOr> absl::InvalidArgumentError) << "Processor index points to the wrong Processor ID."; Instruction* previous_instruction = nullptr; - std::vector processor_coordinates; - std::unordered_set whole_world(comm_group.begin(), comm_group.end()); - // Check if we have non-trivial concentration first - if (concentration_ > 1) { - processor_coordinates = ConsecutiveProcessorIdToGridCoordinates( - processor_id, dimension_sizes_, concentration_); - CommunicationGroup comm_group_conc; - for (uint64_t i = 0; i < concentration_; i++) { - processor_coordinates.at(0) = i; - uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId( - processor_coordinates, dimension_sizes_, concentration_); - if (whole_world.find(new_processor_id) != whole_world.end()) { - comm_group_conc.push_back(new_processor_id); + CommunicationGroup local_comm_group = CommunicationGroupLocalProjection( + processor_id, comm_group, dimension_sizes_, concentration_); + std::vector stage_comm_sizes; + // We prepare communication sizes for each stage and each dimension as + // dimension and/or communication groups could be uneven + for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { + if ((comm_size == 0) || (comm_group.empty())) { + stage_comm_sizes.push_back(0); + } else { + stage_comm_sizes.push_back( + comm_size / comm_group.size() / dimension_sizes_.size()); + if (integrated_local_exchange_) { + // Additionally accomodating split of traffic from concentrator among + // dimensions + stage_comm_sizes.at(dim) /= dimension_sizes_.size(); + } + } + } + // We have as many stages as dimensions in the Torus + for (size_t stage = 0; stage < dimension_sizes_.size(); stage++) { + // We run AllGather in parallel for every dimension of Torus + std::vector parallel_reducescatter; + for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { + auto new_comm_group = CommunicationGroupProjectionOnGrid( + processor_id, comm_group, dim, integrated_local_exchange_, + dimension_sizes_, concentration_); + // Every new stage we should increase communication size + // On the first stage we only exchange data laying in the 1st dimension + // On the second stage we exchange data from both 1st and 2nd dimensions + stage_comm_sizes.at(dim) *= new_comm_group.size(); + // If we don't have any communication in original comm_group within the + // current dimension, just leave it + if (new_comm_group.size() > 1) { + ASSIGN_OR_RETURN(auto reducescatter_stage, Instruction::Create( + Opcode::kReduceScatter, + absl::StrCat(name_prefix, "_stage-", stage, "_dim-", dim), + reducescatter_sub_ptr)); + reducescatter_stage->AppendCommunicationGroup(new_comm_group); + reducescatter_stage->SetBytesOut(stage_comm_sizes.at(dim)); + if (previous_instruction != nullptr) { + reducescatter_stage->AddOperand(previous_instruction); + } + ASSIGN_OR_RETURN(auto reduction_subroutine_stage, + reduction_subroutine->Clone("", /*reset_ids*/ false)); + if ((comm_size != 0) && (stage_comm_sizes.at(dim) != 0)) { + reduction_subroutine_stage->ScalePerformance( + 1.0 * stage_comm_sizes.at(dim) / comm_size); + } + reducescatter_stage->AppendInnerSubroutine(std::move( + reduction_subroutine_stage)); + RETURN_IF_ERROR(reducescatter_translator_->Translate( + reducescatter_stage)); + parallel_reducescatter.push_back(reducescatter_stage); } } - if (comm_group_conc.size() > 1) { + ASSIGN_OR_RETURN(auto reducescatter_root, Instruction::Create( + Opcode::kNull, + absl::StrCat(name_prefix, "_stage-", stage, "_root"), + reducescatter_sub_ptr)); + previous_instruction = reducescatter_root; + for (auto& instr : parallel_reducescatter) { + reducescatter_root->AddOperand(instr); + } + } + // Check if we have non-trivial concentration and need to perform + // explicit local exchange step + if ((concentration_ > 1) && !integrated_local_exchange_) { + if (local_comm_group.size() > 1) { ASSIGN_OR_RETURN(auto reducescatter_conc, Instruction::Create( Opcode::kReduceScatter, - absl::StrCat(name_prefix, - "_dim-conc"), + absl::StrCat(name_prefix, "_conc"), reducescatter_sub_ptr)); - reducescatter_conc->AppendCommunicationGroup(comm_group_conc); - reducescatter_conc->SetBytesOut(comm_size * concentration_ / - comm_group.size()); + reducescatter_conc->AppendCommunicationGroup(local_comm_group); + reducescatter_conc->SetBytesOut(comm_size); + if (previous_instruction != nullptr) { + reducescatter_conc->AddOperand(previous_instruction); + } ASSIGN_OR_RETURN(auto reduction_subroutine_conc, reduction_subroutine->Clone("", /*reset_ids*/ false)); - reduction_subroutine_conc->ScalePerformance(1.0 * concentration_ - / comm_group.size()); reducescatter_conc->AppendInnerSubroutine(std::move( reduction_subroutine_conc)); - RETURN_IF_ERROR(reducescatter_translator_->Translate(reducescatter_conc)); + RETURN_IF_ERROR(reducescatter_translator_->Translate( + reducescatter_conc)); previous_instruction = reducescatter_conc; } } - // Now do the same for every dimension of the mesh - for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { - processor_coordinates = ConsecutiveProcessorIdToGridCoordinates( - processor_id, dimension_sizes_, concentration_); - CommunicationGroup comm_group_mesh; - uint64_t dim_width = dimension_sizes_.at(dim); - for (uint64_t i = 0; i < dim_width; i++) { - processor_coordinates.at(dim + 1) = i; - uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId( - processor_coordinates, dimension_sizes_, concentration_); - if (whole_world.find(new_processor_id) != whole_world.end()) { - comm_group_mesh.push_back(new_processor_id); - } - } - // If we don't have any communication in original comm_group within the - // current dimension, just leave it - if (comm_group_mesh.size() > 1) { - ASSIGN_OR_RETURN(auto reducescatter_mesh, Instruction::Create( - Opcode::kReduceScatter, - absl::StrCat(name_prefix, "_dim-", dim), - reducescatter_sub_ptr)); - reducescatter_mesh->AppendCommunicationGroup(comm_group_mesh); - reducescatter_mesh->SetBytesOut(comm_size * dim_width / - comm_group.size()); - ASSIGN_OR_RETURN(auto reduction_subroutine_mesh, - reduction_subroutine->Clone("", /*reset_ids*/ false)); - reduction_subroutine_mesh->ScalePerformance(1.0 * dim_width - / comm_group.size()); - reducescatter_mesh->AppendInnerSubroutine(std::move( - reduction_subroutine_mesh)); - if (previous_instruction != nullptr) { - reducescatter_mesh->AddOperand(previous_instruction); - } - RETURN_IF_ERROR(reducescatter_translator_->Translate(reducescatter_mesh)); - previous_instruction = reducescatter_mesh; - } - } // Set root instruction for reducescatter subroutine RETURN_IF_ERROR(reducescatter_subroutine->SetRootInstruction( previous_instruction)); diff --git a/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator.h b/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator.h index d298877..a6288d0 100644 --- a/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator.h +++ b/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator.h @@ -60,6 +60,8 @@ class Mesh2dReduceScatterTranslator : public ReduceScatterTranslator { std::vector dimension_sizes_; // Number of processors per mesh node uint64_t concentration_; + // concentrators + bool integrated_local_exchange_; }; } // namespace paragraph diff --git a/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator_test.cc b/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator_test.cc index 3bf8005..1207d12 100644 --- a/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator_test.cc +++ b/paragraph/translation/reducescatter/mesh_2d_reducescatter_translator_test.cc @@ -24,61 +24,9 @@ #include "paragraph/shim/test_macros.h" #include "paragraph/translation/translation_map.h" -// Tests expanding 2D-Mesh reduce-scatter -TEST(Mesh2dReduceScatter, NoBarrier) { - auto graph = absl::make_unique("test_graph", 1); - auto sub = absl::make_unique( - "test_subroutine", graph.get()); - auto sub_ptr = sub.get(); - graph->SetEntrySubroutine(std::move(sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); - instr_1->SetOps(4); - - ASSERT_OK_AND_ASSIGN(auto reducescatter, - paragraph::Instruction::Create( - paragraph::Opcode::kReduceScatter, "reduce-scatter", sub_ptr)); - reducescatter->SetBytesOut(80); - paragraph::CommunicationGroup reducescatter_group = {0, 1, 2, 3, 4, 5, 6, 7}; - reducescatter->AppendCommunicationGroup(reducescatter_group); - - auto reduction_sub = absl::make_unique( - "reduction_subroutine", graph.get()); - auto reduction_ptr = reduction_sub.get(); - ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "op1", reduction_ptr)); - op1->SetBytesOut(80); - ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "op2", reduction_ptr)); - op2->SetBytesOut(80); - ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); - sum_op->SetOps(160); - sum_op->AddOperand(op1); - sum_op->AddOperand(op2); - reducescatter->AppendInnerSubroutine(std::move(reduction_sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); - instr_3->SetOps(4); - - nlohmann::json config = R"( - { - "reduce-scatter": { - "algorithm": "mesh-2d", - "concentration": 2, - "dimension_widths": [2, 2] - } - } - )"_json; - - ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( - paragraph::TranslatorType::kCollective, config)); - EXPECT_OK(translators["reduce-scatter"]->Translate(reducescatter)); - - paragraph::InstructionProto reducescatter_proto; - std::string reducescatter_str = +paragraph::InstructionProto Mesh2dReduceScatter_no_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "reduce-scatter" opcode: "reduce-scatter" @@ -96,25 +44,27 @@ communication_groups { } inner_subroutines { name: "reduce-scatter_mesh-2d" - subroutine_root_id: 19 + subroutine_root_id: 88 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc" + name: "reduce-scatter_stage-0_dim-0" opcode: "reduce-scatter" instruction_id: 7 - bytes_out: 20 + bytes_out: 40 communication_groups { group_ids: 0 group_ids: 1 + group_ids: 2 + group_ids: 3 } inner_subroutines { - name: "reduce-scatter_dim-conc_mesh-1d" - subroutine_root_id: 9 + name: "reduce-scatter_stage-0_dim-0_mesh-1d" + subroutine_root_id: 26 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc_mesh-1d_ccw_sendrecv_0" + name: "reduce-scatter_stage-0_dim-0_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" instruction_id: 8 bytes_in: 10 @@ -125,7 +75,7 @@ inner_subroutines { } } instructions { - name: "reduce-scatter_dim-conc_mesh-1d_ccw_reduction_0" + name: "reduce-scatter_stage-0_dim-0_mesh-1d_ccw_reduction_0" opcode: "call" instruction_id: 9 operand_ids: 8 @@ -156,143 +106,704 @@ inner_subroutines { } } } + instructions { + name: "reduce-scatter_stage-0_dim-0_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 13 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 2 + group_ids: 2 + } + } + instructions { + name: "reduce-scatter_stage-0_dim-0_mesh-1d_cw_reduction_0" + opcode: "call" + instruction_id: 14 + operand_ids: 13 + inner_subroutines { + name: "reduction_subroutine_cw_phase_0" + subroutine_root_id: 17 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_cw_phase_0" + opcode: "delay" + instruction_id: 15 + bytes_out: 10 + } + instructions { + name: "op2_cw_phase_0" + opcode: "delay" + instruction_id: 16 + bytes_out: 10 + } + instructions { + name: "sum_cw_phase_0" + opcode: "delay" + instruction_id: 17 + ops: 20 + operand_ids: 15 + operand_ids: 16 + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-0_mesh-1d_root_0" + opcode: "null" + instruction_id: 18 + operand_ids: 14 + operand_ids: 9 + } + instructions { + name: "reduce-scatter_stage-0_dim-0_mesh-1d_ccw_send_1" + opcode: "send" + instruction_id: 19 + bytes_out: 10 + communication_groups { + group_ids: 0 + } + operand_ids: 18 + } + instructions { + name: "reduce-scatter_stage-0_dim-0_mesh-1d_cw_sendrecv_1" + opcode: "sendrecv" + instruction_id: 20 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 2 + group_ids: 2 + } + operand_ids: 18 + } + instructions { + name: "reduce-scatter_stage-0_dim-0_mesh-1d_cw_reduction_1" + opcode: "call" + instruction_id: 21 + operand_ids: 20 + inner_subroutines { + name: "reduction_subroutine_cw_phase_1" + subroutine_root_id: 24 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_cw_phase_1" + opcode: "delay" + instruction_id: 22 + bytes_out: 10 + } + instructions { + name: "op2_cw_phase_1" + opcode: "delay" + instruction_id: 23 + bytes_out: 10 + } + instructions { + name: "sum_cw_phase_1" + opcode: "delay" + instruction_id: 24 + ops: 20 + operand_ids: 22 + operand_ids: 23 + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-0_mesh-1d_root_1" + opcode: "null" + instruction_id: 25 + operand_ids: 21 + operand_ids: 19 + } + instructions { + name: "reduce-scatter_stage-0_dim-0_mesh-1d_ccw_send_2" + opcode: "send" + instruction_id: 26 + bytes_out: 10 + communication_groups { + group_ids: 0 + } + operand_ids: 25 + } } } instructions { - name: "reduce-scatter_dim-0" + name: "reduce-scatter_stage-0_dim-1" opcode: "reduce-scatter" - instruction_id: 13 - bytes_out: 20 + instruction_id: 27 + bytes_out: 40 communication_groups { + group_ids: 0 group_ids: 1 - group_ids: 3 + group_ids: 4 + group_ids: 5 } - operand_ids: 7 inner_subroutines { - name: "reduce-scatter_dim-0_mesh-1d" - subroutine_root_id: 15 + name: "reduce-scatter_stage-0_dim-1_mesh-1d" + subroutine_root_id: 46 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-0_mesh-1d_cw_sendrecv_0" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" - instruction_id: 14 + instruction_id: 28 bytes_in: 10 bytes_out: 10 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 0 } } instructions { - name: "reduce-scatter_dim-0_mesh-1d_cw_reduction_0" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_ccw_reduction_0" opcode: "call" - instruction_id: 15 - operand_ids: 14 + instruction_id: 29 + operand_ids: 28 + inner_subroutines { + name: "reduction_subroutine_ccw_phase_0" + subroutine_root_id: 32 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_ccw_phase_0" + opcode: "delay" + instruction_id: 30 + bytes_out: 10 + } + instructions { + name: "op2_ccw_phase_0" + opcode: "delay" + instruction_id: 31 + bytes_out: 10 + } + instructions { + name: "sum_ccw_phase_0" + opcode: "delay" + instruction_id: 32 + ops: 20 + operand_ids: 30 + operand_ids: 31 + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 33 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 4 + group_ids: 4 + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_mesh-1d_cw_reduction_0" + opcode: "call" + instruction_id: 34 + operand_ids: 33 inner_subroutines { name: "reduction_subroutine_cw_phase_0" - subroutine_root_id: 18 + subroutine_root_id: 37 execution_probability: 1 execution_count: 1 instructions { name: "op1_cw_phase_0" opcode: "delay" - instruction_id: 16 + instruction_id: 35 bytes_out: 10 } instructions { name: "op2_cw_phase_0" opcode: "delay" - instruction_id: 17 + instruction_id: 36 bytes_out: 10 } instructions { name: "sum_cw_phase_0" opcode: "delay" - instruction_id: 18 + instruction_id: 37 ops: 20 - operand_ids: 16 - operand_ids: 17 + operand_ids: 35 + operand_ids: 36 } } } - } - } - instructions { - name: "reduce-scatter_dim-1" - opcode: "reduce-scatter" - instruction_id: 19 - bytes_out: 20 - communication_groups { - group_ids: 1 - group_ids: 5 - } - operand_ids: 13 - inner_subroutines { - name: "reduce-scatter_dim-1_mesh-1d" - subroutine_root_id: 21 - execution_probability: 1 - execution_count: 1 instructions { - name: "reduce-scatter_dim-1_mesh-1d_cw_sendrecv_0" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_root_0" + opcode: "null" + instruction_id: 38 + operand_ids: 34 + operand_ids: 29 + } + instructions { + name: "reduce-scatter_stage-0_dim-1_mesh-1d_ccw_send_1" + opcode: "send" + instruction_id: 39 + bytes_out: 10 + communication_groups { + group_ids: 0 + } + operand_ids: 38 + } + instructions { + name: "reduce-scatter_stage-0_dim-1_mesh-1d_cw_sendrecv_1" + opcode: "sendrecv" + instruction_id: 40 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 4 + group_ids: 4 + } + operand_ids: 38 + } + instructions { + name: "reduce-scatter_stage-0_dim-1_mesh-1d_cw_reduction_1" + opcode: "call" + instruction_id: 41 + operand_ids: 40 + inner_subroutines { + name: "reduction_subroutine_cw_phase_1" + subroutine_root_id: 44 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_cw_phase_1" + opcode: "delay" + instruction_id: 42 + bytes_out: 10 + } + instructions { + name: "op2_cw_phase_1" + opcode: "delay" + instruction_id: 43 + bytes_out: 10 + } + instructions { + name: "sum_cw_phase_1" + opcode: "delay" + instruction_id: 44 + ops: 20 + operand_ids: 42 + operand_ids: 43 + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_mesh-1d_root_1" + opcode: "null" + instruction_id: 45 + operand_ids: 41 + operand_ids: 39 + } + instructions { + name: "reduce-scatter_stage-0_dim-1_mesh-1d_ccw_send_2" + opcode: "send" + instruction_id: 46 + bytes_out: 10 + communication_groups { + group_ids: 0 + } + operand_ids: 45 + } + } + } + instructions { + name: "reduce-scatter_stage-0_root" + opcode: "null" + instruction_id: 47 + operand_ids: 7 + operand_ids: 27 + } + instructions { + name: "reduce-scatter_stage-1_dim-0" + opcode: "reduce-scatter" + instruction_id: 48 + bytes_out: 80 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 2 + group_ids: 3 + } + operand_ids: 47 + inner_subroutines { + name: "reduce-scatter_stage-1_dim-0_mesh-1d" + subroutine_root_id: 67 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 49 + bytes_in: 20 + bytes_out: 20 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_ccw_reduction_0" + opcode: "call" + instruction_id: 50 + operand_ids: 49 + inner_subroutines { + name: "reduction_subroutine_ccw_phase_0" + subroutine_root_id: 53 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_ccw_phase_0" + opcode: "delay" + instruction_id: 51 + bytes_out: 20 + } + instructions { + name: "op2_ccw_phase_0" + opcode: "delay" + instruction_id: 52 + bytes_out: 20 + } + instructions { + name: "sum_ccw_phase_0" + opcode: "delay" + instruction_id: 53 + ops: 40 + operand_ids: 51 + operand_ids: 52 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 54 + bytes_in: 20 + bytes_out: 20 + communication_groups { + group_ids: 2 + group_ids: 2 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_cw_reduction_0" + opcode: "call" + instruction_id: 55 + operand_ids: 54 + inner_subroutines { + name: "reduction_subroutine_cw_phase_0" + subroutine_root_id: 58 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_cw_phase_0" + opcode: "delay" + instruction_id: 56 + bytes_out: 20 + } + instructions { + name: "op2_cw_phase_0" + opcode: "delay" + instruction_id: 57 + bytes_out: 20 + } + instructions { + name: "sum_cw_phase_0" + opcode: "delay" + instruction_id: 58 + ops: 40 + operand_ids: 56 + operand_ids: 57 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_root_0" + opcode: "null" + instruction_id: 59 + operand_ids: 55 + operand_ids: 50 + } + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_ccw_send_1" + opcode: "send" + instruction_id: 60 + bytes_out: 20 + communication_groups { + group_ids: 0 + } + operand_ids: 59 + } + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_cw_sendrecv_1" + opcode: "sendrecv" + instruction_id: 61 + bytes_in: 20 + bytes_out: 20 + communication_groups { + group_ids: 2 + group_ids: 2 + } + operand_ids: 59 + } + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_cw_reduction_1" + opcode: "call" + instruction_id: 62 + operand_ids: 61 + inner_subroutines { + name: "reduction_subroutine_cw_phase_1" + subroutine_root_id: 65 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_cw_phase_1" + opcode: "delay" + instruction_id: 63 + bytes_out: 20 + } + instructions { + name: "op2_cw_phase_1" + opcode: "delay" + instruction_id: 64 + bytes_out: 20 + } + instructions { + name: "sum_cw_phase_1" + opcode: "delay" + instruction_id: 65 + ops: 40 + operand_ids: 63 + operand_ids: 64 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_root_1" + opcode: "null" + instruction_id: 66 + operand_ids: 62 + operand_ids: 60 + } + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_ccw_send_2" + opcode: "send" + instruction_id: 67 + bytes_out: 20 + communication_groups { + group_ids: 0 + } + operand_ids: 66 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1" + opcode: "reduce-scatter" + instruction_id: 68 + bytes_out: 80 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 4 + group_ids: 5 + } + operand_ids: 47 + inner_subroutines { + name: "reduce-scatter_stage-1_dim-1_mesh-1d" + subroutine_root_id: 87 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 69 + bytes_in: 20 + bytes_out: 20 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_ccw_reduction_0" + opcode: "call" + instruction_id: 70 + operand_ids: 69 + inner_subroutines { + name: "reduction_subroutine_ccw_phase_0" + subroutine_root_id: 73 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_ccw_phase_0" + opcode: "delay" + instruction_id: 71 + bytes_out: 20 + } + instructions { + name: "op2_ccw_phase_0" + opcode: "delay" + instruction_id: 72 + bytes_out: 20 + } + instructions { + name: "sum_ccw_phase_0" + opcode: "delay" + instruction_id: 73 + ops: 40 + operand_ids: 71 + operand_ids: 72 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 74 + bytes_in: 20 + bytes_out: 20 + communication_groups { + group_ids: 4 + group_ids: 4 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_cw_reduction_0" + opcode: "call" + instruction_id: 75 + operand_ids: 74 + inner_subroutines { + name: "reduction_subroutine_cw_phase_0" + subroutine_root_id: 78 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_cw_phase_0" + opcode: "delay" + instruction_id: 76 + bytes_out: 20 + } + instructions { + name: "op2_cw_phase_0" + opcode: "delay" + instruction_id: 77 + bytes_out: 20 + } + instructions { + name: "sum_cw_phase_0" + opcode: "delay" + instruction_id: 78 + ops: 40 + operand_ids: 76 + operand_ids: 77 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_root_0" + opcode: "null" + instruction_id: 79 + operand_ids: 75 + operand_ids: 70 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_ccw_send_1" + opcode: "send" + instruction_id: 80 + bytes_out: 20 + communication_groups { + group_ids: 0 + } + operand_ids: 79 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_cw_sendrecv_1" opcode: "sendrecv" - instruction_id: 20 - bytes_in: 10 - bytes_out: 10 + instruction_id: 81 + bytes_in: 20 + bytes_out: 20 communication_groups { - group_ids: 5 - group_ids: 5 + group_ids: 4 + group_ids: 4 } + operand_ids: 79 } instructions { - name: "reduce-scatter_dim-1_mesh-1d_cw_reduction_0" + name: "reduce-scatter_stage-1_dim-1_mesh-1d_cw_reduction_1" opcode: "call" - instruction_id: 21 - operand_ids: 20 + instruction_id: 82 + operand_ids: 81 inner_subroutines { - name: "reduction_subroutine_cw_phase_0" - subroutine_root_id: 24 + name: "reduction_subroutine_cw_phase_1" + subroutine_root_id: 85 execution_probability: 1 execution_count: 1 instructions { - name: "op1_cw_phase_0" + name: "op1_cw_phase_1" opcode: "delay" - instruction_id: 22 - bytes_out: 10 + instruction_id: 83 + bytes_out: 20 } instructions { - name: "op2_cw_phase_0" + name: "op2_cw_phase_1" opcode: "delay" - instruction_id: 23 - bytes_out: 10 + instruction_id: 84 + bytes_out: 20 } instructions { - name: "sum_cw_phase_0" + name: "sum_cw_phase_1" opcode: "delay" - instruction_id: 24 - ops: 20 - operand_ids: 22 - operand_ids: 23 + instruction_id: 85 + ops: 40 + operand_ids: 83 + operand_ids: 84 } } } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_root_1" + opcode: "null" + instruction_id: 86 + operand_ids: 82 + operand_ids: 80 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_ccw_send_2" + opcode: "send" + instruction_id: 87 + bytes_out: 20 + communication_groups { + group_ids: 0 + } + operand_ids: 86 + } } } + instructions { + name: "reduce-scatter_stage-1_root" + opcode: "null" + instruction_id: 88 + operand_ids: 48 + operand_ids: 68 + } } )proto"; - google::protobuf::TextFormat::ParseFromString(reducescatter_str, - &reducescatter_proto); - EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( - reducescatter->ToProto().value(), reducescatter_proto)); -} + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT -// Tests expanding 1D-Mesh reduce-scatter with barrier -TEST(Mesh2dReduceScatter, WithBarrier) { - auto graph = absl::make_unique("test_graph", 2); +// Tests expanding 2D-Mesh reduce-scatter +TEST(Mesh2dReduceScatter, NoBarrier) { + auto graph = absl::make_unique("test_graph", 1); auto sub = absl::make_unique( "test_subroutine", graph.get()); auto sub_ptr = sub.get(); - sub_ptr->SetId(3); graph->SetEntrySubroutine(std::move(sub)); ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( @@ -332,9 +843,7 @@ TEST(Mesh2dReduceScatter, WithBarrier) { "algorithm": "mesh-2d", "concentration": 2, "dimension_widths": [2, 2], - "barrier": { - "algorithm": "centralized" - } + "integrated_local_exchange": true } } )"_json; @@ -343,8 +852,16 @@ TEST(Mesh2dReduceScatter, WithBarrier) { paragraph::TranslatorType::kCollective, config)); EXPECT_OK(translators["reduce-scatter"]->Translate(reducescatter)); - paragraph::InstructionProto reducescatter_proto; - std::string reducescatter_str = + paragraph::InstructionProto reducescatter_proto = + Mesh2dReduceScatter_no_barrier_test_proto(); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + reducescatter->ToProto().value(), reducescatter_proto)); +} + +paragraph::InstructionProto +Mesh2dReduceScatter_with_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "reduce-scatter" opcode: "reduce-scatter" @@ -362,294 +879,499 @@ communication_groups { } inner_subroutines { name: "reduce-scatter_mesh-2d" - subroutine_root_id: 26 + subroutine_root_id: 47 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc" + name: "reduce-scatter_stage-0_dim-0" opcode: "reduce-scatter" instruction_id: 7 bytes_out: 20 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "reduce-scatter_dim-conc_mesh-1d" - subroutine_root_id: 13 + name: "reduce-scatter_stage-0_dim-0_mesh-1d" + subroutine_root_id: 12 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc_unidir-ring_barrier" + name: "reduce-scatter_stage-0_dim-0_unidir-ring_barrier" opcode: "barrier" instruction_id: 8 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "reduce-scatter_dim-conc_unidir-ring_barrier_centralized" - subroutine_root_id: 11 + name: "reduce-scatter_stage-0_dim-0_unidir-ring_barrier_centralized" + subroutine_root_id: 10 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc_unidir-ring_barrier_centralized_coordinator_recv_from_3" - opcode: "recv" + name: "reduce-scatter_stage-0_dim-0_unidir-ring_barrier_centralized_send_to_0" + opcode: "send" instruction_id: 9 communication_groups { - group_ids: 3 + group_ids: 0 } } instructions { - name: "reduce-scatter_dim-conc_unidir-ring_barrier_centralized_coordinator_send_to_3" - opcode: "send" + name: "reduce-scatter_stage-0_dim-0_unidir-ring_barrier_centralized_recv_from_0" + opcode: "recv" instruction_id: 10 communication_groups { - group_ids: 3 + group_ids: 0 } operand_ids: 9 } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-0_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 11 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 0 + } + operand_ids: 8 + } + instructions { + name: "reduce-scatter_stage-0_dim-0_mesh-1d_ccw_reduction_0" + opcode: "call" + instruction_id: 12 + operand_ids: 11 + inner_subroutines { + name: "reduction_subroutine_ccw_phase_0" + subroutine_root_id: 15 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_ccw_phase_0" + opcode: "delay" + instruction_id: 13 + bytes_out: 10 + } + instructions { + name: "op2_ccw_phase_0" + opcode: "delay" + instruction_id: 14 + bytes_out: 10 + } + instructions { + name: "sum_ccw_phase_0" + opcode: "delay" + instruction_id: 15 + ops: 20 + operand_ids: 13 + operand_ids: 14 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1" + opcode: "reduce-scatter" + instruction_id: 16 + bytes_out: 20 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "reduce-scatter_stage-0_dim-1_mesh-1d" + subroutine_root_id: 22 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-1_unidir-ring_barrier" + opcode: "barrier" + instruction_id: 17 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "reduce-scatter_stage-0_dim-1_unidir-ring_barrier_centralized" + subroutine_root_id: 20 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-1_unidir-ring_barrier_centralized_coordinator_recv_from_6" + opcode: "recv" + instruction_id: 18 + communication_groups { + group_ids: 6 + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_unidir-ring_barrier_centralized_coordinator_send_to_6" + opcode: "send" + instruction_id: 19 + communication_groups { + group_ids: 6 + } + operand_ids: 18 + } instructions { - name: "reduce-scatter_dim-conc_unidir-ring_barrier_centralized_root_2" + name: "reduce-scatter_stage-0_dim-1_unidir-ring_barrier_centralized_root_2" opcode: "null" - instruction_id: 11 - operand_ids: 10 + instruction_id: 20 + operand_ids: 19 } } } instructions { - name: "reduce-scatter_dim-conc_mesh-1d_cw_sendrecv_0" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" - instruction_id: 12 + instruction_id: 21 bytes_in: 10 bytes_out: 10 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 6 + group_ids: 6 } - operand_ids: 8 + operand_ids: 17 } instructions { - name: "reduce-scatter_dim-conc_mesh-1d_cw_reduction_0" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_cw_reduction_0" opcode: "call" - instruction_id: 13 - operand_ids: 12 + instruction_id: 22 + operand_ids: 21 inner_subroutines { name: "reduction_subroutine_cw_phase_0" - subroutine_root_id: 16 + subroutine_root_id: 25 execution_probability: 1 execution_count: 1 instructions { name: "op1_cw_phase_0" opcode: "delay" - instruction_id: 14 + instruction_id: 23 bytes_out: 10 } instructions { name: "op2_cw_phase_0" opcode: "delay" - instruction_id: 15 + instruction_id: 24 bytes_out: 10 } instructions { name: "sum_cw_phase_0" opcode: "delay" - instruction_id: 16 + instruction_id: 25 ops: 20 - operand_ids: 14 - operand_ids: 15 + operand_ids: 23 + operand_ids: 24 } } } } } instructions { - name: "reduce-scatter_dim-0" + name: "reduce-scatter_stage-0_root" + opcode: "null" + instruction_id: 26 + operand_ids: 7 + operand_ids: 16 + } + instructions { + name: "reduce-scatter_stage-1_dim-0" opcode: "reduce-scatter" - instruction_id: 17 - bytes_out: 20 + instruction_id: 27 + bytes_out: 40 communication_groups { group_ids: 0 group_ids: 2 } - operand_ids: 7 + operand_ids: 26 + inner_subroutines { + name: "reduce-scatter_stage-1_dim-0_mesh-1d" + subroutine_root_id: 32 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-0_unidir-ring_barrier" + opcode: "barrier" + instruction_id: 28 + communication_groups { + group_ids: 0 + group_ids: 2 + } + inner_subroutines { + name: "reduce-scatter_stage-1_dim-0_unidir-ring_barrier_centralized" + subroutine_root_id: 30 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-0_unidir-ring_barrier_centralized_send_to_0" + opcode: "send" + instruction_id: 29 + communication_groups { + group_ids: 0 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_unidir-ring_barrier_centralized_recv_from_0" + opcode: "recv" + instruction_id: 30 + communication_groups { + group_ids: 0 + } + operand_ids: 29 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 31 + bytes_in: 20 + bytes_out: 20 + communication_groups { + group_ids: 0 + group_ids: 0 + } + operand_ids: 28 + } + instructions { + name: "reduce-scatter_stage-1_dim-0_mesh-1d_ccw_reduction_0" + opcode: "call" + instruction_id: 32 + operand_ids: 31 + inner_subroutines { + name: "reduction_subroutine_ccw_phase_0" + subroutine_root_id: 35 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_ccw_phase_0" + opcode: "delay" + instruction_id: 33 + bytes_out: 20 + } + instructions { + name: "op2_ccw_phase_0" + opcode: "delay" + instruction_id: 34 + bytes_out: 20 + } + instructions { + name: "sum_ccw_phase_0" + opcode: "delay" + instruction_id: 35 + ops: 40 + operand_ids: 33 + operand_ids: 34 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1" + opcode: "reduce-scatter" + instruction_id: 36 + bytes_out: 40 + communication_groups { + group_ids: 2 + group_ids: 6 + } + operand_ids: 26 inner_subroutines { - name: "reduce-scatter_dim-0_mesh-1d" - subroutine_root_id: 22 + name: "reduce-scatter_stage-1_dim-1_mesh-1d" + subroutine_root_id: 42 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-0_unidir-ring_barrier" + name: "reduce-scatter_stage-1_dim-1_unidir-ring_barrier" opcode: "barrier" - instruction_id: 18 + instruction_id: 37 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } inner_subroutines { - name: "reduce-scatter_dim-0_unidir-ring_barrier_centralized" - subroutine_root_id: 20 + name: "reduce-scatter_stage-1_dim-1_unidir-ring_barrier_centralized" + subroutine_root_id: 40 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-0_unidir-ring_barrier_centralized_send_to_0" - opcode: "send" - instruction_id: 19 + name: "reduce-scatter_stage-1_dim-1_unidir-ring_barrier_centralized_coordinator_recv_from_6" + opcode: "recv" + instruction_id: 38 communication_groups { - group_ids: 0 + group_ids: 6 } } instructions { - name: "reduce-scatter_dim-0_unidir-ring_barrier_centralized_recv_from_0" - opcode: "recv" - instruction_id: 20 + name: "reduce-scatter_stage-1_dim-1_unidir-ring_barrier_centralized_coordinator_send_to_6" + opcode: "send" + instruction_id: 39 communication_groups { - group_ids: 0 + group_ids: 6 } - operand_ids: 19 + operand_ids: 38 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_unidir-ring_barrier_centralized_root_2" + opcode: "null" + instruction_id: 40 + operand_ids: 39 } } } instructions { - name: "reduce-scatter_dim-0_mesh-1d_ccw_sendrecv_0" + name: "reduce-scatter_stage-1_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" - instruction_id: 21 - bytes_in: 10 - bytes_out: 10 + instruction_id: 41 + bytes_in: 20 + bytes_out: 20 communication_groups { - group_ids: 0 - group_ids: 0 + group_ids: 6 + group_ids: 6 } - operand_ids: 18 + operand_ids: 37 } instructions { - name: "reduce-scatter_dim-0_mesh-1d_ccw_reduction_0" + name: "reduce-scatter_stage-1_dim-1_mesh-1d_cw_reduction_0" opcode: "call" - instruction_id: 22 - operand_ids: 21 + instruction_id: 42 + operand_ids: 41 inner_subroutines { - name: "reduction_subroutine_ccw_phase_0" - subroutine_root_id: 25 + name: "reduction_subroutine_cw_phase_0" + subroutine_root_id: 45 execution_probability: 1 execution_count: 1 instructions { - name: "op1_ccw_phase_0" + name: "op1_cw_phase_0" opcode: "delay" - instruction_id: 23 - bytes_out: 10 + instruction_id: 43 + bytes_out: 20 } instructions { - name: "op2_ccw_phase_0" + name: "op2_cw_phase_0" opcode: "delay" - instruction_id: 24 - bytes_out: 10 + instruction_id: 44 + bytes_out: 20 } instructions { - name: "sum_ccw_phase_0" + name: "sum_cw_phase_0" opcode: "delay" - instruction_id: 25 - ops: 20 - operand_ids: 23 - operand_ids: 24 + instruction_id: 45 + ops: 40 + operand_ids: 43 + operand_ids: 44 } } } } } instructions { - name: "reduce-scatter_dim-1" + name: "reduce-scatter_stage-1_root" + opcode: "null" + instruction_id: 46 + operand_ids: 27 + operand_ids: 36 + } + instructions { + name: "reduce-scatter_conc" opcode: "reduce-scatter" - instruction_id: 26 - bytes_out: 20 + instruction_id: 47 + bytes_out: 80 communication_groups { group_ids: 2 - group_ids: 6 + group_ids: 3 } - operand_ids: 17 + operand_ids: 46 inner_subroutines { - name: "reduce-scatter_dim-1_mesh-1d" - subroutine_root_id: 32 + name: "reduce-scatter_conc_mesh-1d" + subroutine_root_id: 53 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1_unidir-ring_barrier" + name: "reduce-scatter_conc_unidir-ring_barrier" opcode: "barrier" - instruction_id: 27 + instruction_id: 48 communication_groups { group_ids: 2 - group_ids: 6 + group_ids: 3 } inner_subroutines { - name: "reduce-scatter_dim-1_unidir-ring_barrier_centralized" - subroutine_root_id: 30 + name: "reduce-scatter_conc_unidir-ring_barrier_centralized" + subroutine_root_id: 51 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1_unidir-ring_barrier_centralized_coordinator_recv_from_6" + name: "reduce-scatter_conc_unidir-ring_barrier_centralized_coordinator_recv_from_3" opcode: "recv" - instruction_id: 28 + instruction_id: 49 communication_groups { - group_ids: 6 + group_ids: 3 } } instructions { - name: "reduce-scatter_dim-1_unidir-ring_barrier_centralized_coordinator_send_to_6" + name: "reduce-scatter_conc_unidir-ring_barrier_centralized_coordinator_send_to_3" opcode: "send" - instruction_id: 29 + instruction_id: 50 communication_groups { - group_ids: 6 + group_ids: 3 } - operand_ids: 28 + operand_ids: 49 } instructions { - name: "reduce-scatter_dim-1_unidir-ring_barrier_centralized_root_2" + name: "reduce-scatter_conc_unidir-ring_barrier_centralized_root_2" opcode: "null" - instruction_id: 30 - operand_ids: 29 + instruction_id: 51 + operand_ids: 50 } } } instructions { - name: "reduce-scatter_dim-1_mesh-1d_cw_sendrecv_0" + name: "reduce-scatter_conc_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" - instruction_id: 31 - bytes_in: 10 - bytes_out: 10 + instruction_id: 52 + bytes_in: 40 + bytes_out: 40 communication_groups { - group_ids: 6 - group_ids: 6 + group_ids: 3 + group_ids: 3 } - operand_ids: 27 + operand_ids: 48 } instructions { - name: "reduce-scatter_dim-1_mesh-1d_cw_reduction_0" + name: "reduce-scatter_conc_mesh-1d_cw_reduction_0" opcode: "call" - instruction_id: 32 - operand_ids: 31 + instruction_id: 53 + operand_ids: 52 inner_subroutines { name: "reduction_subroutine_cw_phase_0" - subroutine_root_id: 35 + subroutine_root_id: 56 execution_probability: 1 execution_count: 1 instructions { name: "op1_cw_phase_0" opcode: "delay" - instruction_id: 33 - bytes_out: 10 + instruction_id: 54 + bytes_out: 40 } instructions { name: "op2_cw_phase_0" opcode: "delay" - instruction_id: 34 - bytes_out: 10 + instruction_id: 55 + bytes_out: 40 } instructions { name: "sum_cw_phase_0" opcode: "delay" - instruction_id: 35 - ops: 20 - operand_ids: 33 - operand_ids: 34 + instruction_id: 56 + ops: 80 + operand_ids: 54 + operand_ids: 55 } } } @@ -657,18 +1379,18 @@ inner_subroutines { } } )proto"; - google::protobuf::TextFormat::ParseFromString(reducescatter_str, - &reducescatter_proto); - EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( - reducescatter->ToProto().value(), reducescatter_proto)); -} + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT -// Tests expanding 1D-Mesh reduce-scatter -TEST(Mesh2dReduceScatter, InconsecutiveProcessors) { +// Tests expanding 1D-Mesh reduce-scatter with barrier +TEST(Mesh2dReduceScatter, WithBarrier) { auto graph = absl::make_unique("test_graph", 2); auto sub = absl::make_unique( "test_subroutine", graph.get()); auto sub_ptr = sub.get(); + sub_ptr->SetId(3); graph->SetEntrySubroutine(std::move(sub)); ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( @@ -678,8 +1400,8 @@ TEST(Mesh2dReduceScatter, InconsecutiveProcessors) { ASSERT_OK_AND_ASSIGN(auto reducescatter, paragraph::Instruction::Create( paragraph::Opcode::kReduceScatter, "reduce-scatter", sub_ptr)); - reducescatter->SetBytesOut(48); - paragraph::CommunicationGroup reducescatter_group = {0, 2, 4}; + reducescatter->SetBytesOut(80); + paragraph::CommunicationGroup reducescatter_group = {0, 1, 2, 3, 4, 5, 6, 7}; reducescatter->AppendCommunicationGroup(reducescatter_group); auto reduction_sub = absl::make_unique( @@ -687,13 +1409,13 @@ TEST(Mesh2dReduceScatter, InconsecutiveProcessors) { auto reduction_ptr = reduction_sub.get(); ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( paragraph::Opcode::kDelay, "op1", reduction_ptr)); - op1->SetBytesOut(48); + op1->SetBytesOut(80); ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( paragraph::Opcode::kDelay, "op2", reduction_ptr)); - op2->SetBytesOut(48); + op2->SetBytesOut(80); ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); - sum_op->SetOps(96); + sum_op->SetOps(160); sum_op->AddOperand(op1); sum_op->AddOperand(op2); reducescatter->AppendInnerSubroutine(std::move(reduction_sub)); @@ -706,7 +1428,11 @@ TEST(Mesh2dReduceScatter, InconsecutiveProcessors) { { "reduce-scatter": { "algorithm": "mesh-2d", - "dimension_widths": [2, 3] + "concentration": 2, + "dimension_widths": [2, 2], + "barrier": { + "algorithm": "centralized" + } } } )"_json; @@ -715,8 +1441,16 @@ TEST(Mesh2dReduceScatter, InconsecutiveProcessors) { paragraph::TranslatorType::kCollective, config)); EXPECT_OK(translators["reduce-scatter"]->Translate(reducescatter)); - paragraph::InstructionProto reducescatter_proto; - std::string reducescatter_str = + paragraph::InstructionProto reducescatter_proto = + Mesh2dReduceScatter_with_barrier_test_proto(); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + reducescatter->ToProto().value(), reducescatter_proto)); +} + +paragraph::InstructionProto +Mesh2dReduceScatter_inconsecutive_proc_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "reduce-scatter" opcode: "reduce-scatter" @@ -729,11 +1463,11 @@ communication_groups { } inner_subroutines { name: "reduce-scatter_mesh-2d" - subroutine_root_id: 7 + subroutine_root_id: 38 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1" + name: "reduce-scatter_stage-0_dim-1" opcode: "reduce-scatter" instruction_id: 7 bytes_out: 48 @@ -743,12 +1477,12 @@ inner_subroutines { group_ids: 4 } inner_subroutines { - name: "reduce-scatter_dim-1_mesh-1d" + name: "reduce-scatter_stage-0_dim-1_mesh-1d" subroutine_root_id: 21 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1_mesh-1d_ccw_sendrecv_0" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_ccw_sendrecv_0" opcode: "sendrecv" instruction_id: 8 bytes_in: 16 @@ -759,7 +1493,7 @@ inner_subroutines { } } instructions { - name: "reduce-scatter_dim-1_mesh-1d_ccw_reduction_0" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_ccw_reduction_0" opcode: "call" instruction_id: 9 operand_ids: 8 @@ -791,7 +1525,7 @@ inner_subroutines { } } instructions { - name: "reduce-scatter_dim-1_mesh-1d_cw_sendrecv_0" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_cw_sendrecv_0" opcode: "sendrecv" instruction_id: 13 bytes_in: 16 @@ -802,7 +1536,7 @@ inner_subroutines { } } instructions { - name: "reduce-scatter_dim-1_mesh-1d_cw_reduction_0" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_cw_reduction_0" opcode: "call" instruction_id: 14 operand_ids: 13 @@ -834,14 +1568,14 @@ inner_subroutines { } } instructions { - name: "reduce-scatter_dim-1_mesh-1d_root_0" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_root_0" opcode: "null" instruction_id: 18 operand_ids: 14 operand_ids: 9 } instructions { - name: "reduce-scatter_dim-1_mesh-1d_ccw_send_1" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_ccw_send_1" opcode: "send" instruction_id: 19 bytes_out: 16 @@ -851,7 +1585,7 @@ inner_subroutines { operand_ids: 18 } instructions { - name: "reduce-scatter_dim-1_mesh-1d_cw_send_1" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_cw_send_1" opcode: "send" instruction_id: 20 bytes_out: 16 @@ -861,7 +1595,7 @@ inner_subroutines { operand_ids: 18 } instructions { - name: "reduce-scatter_dim-1_mesh-1d_root_1" + name: "reduce-scatter_stage-0_dim-1_mesh-1d_root_1" opcode: "null" instruction_id: 21 operand_ids: 20 @@ -869,10 +1603,217 @@ inner_subroutines { } } } + instructions { + name: "reduce-scatter_stage-0_root" + opcode: "null" + instruction_id: 22 + operand_ids: 7 + } + instructions { + name: "reduce-scatter_stage-1_dim-1" + opcode: "reduce-scatter" + instruction_id: 23 + bytes_out: 144 + communication_groups { + group_ids: 0 + group_ids: 2 + group_ids: 4 + } + operand_ids: 22 + inner_subroutines { + name: "reduce-scatter_stage-1_dim-1_mesh-1d" + subroutine_root_id: 37 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_ccw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 24 + bytes_in: 48 + bytes_out: 48 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_ccw_reduction_0" + opcode: "call" + instruction_id: 25 + operand_ids: 24 + inner_subroutines { + name: "reduction_subroutine_ccw_phase_0" + subroutine_root_id: 28 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_ccw_phase_0" + opcode: "delay" + instruction_id: 26 + bytes_out: 48 + } + instructions { + name: "op2_ccw_phase_0" + opcode: "delay" + instruction_id: 27 + bytes_out: 48 + } + instructions { + name: "sum_ccw_phase_0" + opcode: "delay" + instruction_id: 28 + ops: 96 + operand_ids: 26 + operand_ids: 27 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_cw_sendrecv_0" + opcode: "sendrecv" + instruction_id: 29 + bytes_in: 48 + bytes_out: 48 + communication_groups { + group_ids: 4 + group_ids: 4 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_cw_reduction_0" + opcode: "call" + instruction_id: 30 + operand_ids: 29 + inner_subroutines { + name: "reduction_subroutine_cw_phase_0" + subroutine_root_id: 33 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_cw_phase_0" + opcode: "delay" + instruction_id: 31 + bytes_out: 48 + } + instructions { + name: "op2_cw_phase_0" + opcode: "delay" + instruction_id: 32 + bytes_out: 48 + } + instructions { + name: "sum_cw_phase_0" + opcode: "delay" + instruction_id: 33 + ops: 96 + operand_ids: 31 + operand_ids: 32 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_root_0" + opcode: "null" + instruction_id: 34 + operand_ids: 30 + operand_ids: 25 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_ccw_send_1" + opcode: "send" + instruction_id: 35 + bytes_out: 48 + communication_groups { + group_ids: 0 + } + operand_ids: 34 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_cw_send_1" + opcode: "send" + instruction_id: 36 + bytes_out: 48 + communication_groups { + group_ids: 4 + } + operand_ids: 34 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_mesh-1d_root_1" + opcode: "null" + instruction_id: 37 + operand_ids: 36 + operand_ids: 35 + } + } + } + instructions { + name: "reduce-scatter_stage-1_root" + opcode: "null" + instruction_id: 38 + operand_ids: 23 + } } )proto"; - google::protobuf::TextFormat::ParseFromString(reducescatter_str, - &reducescatter_proto); + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT + +// Tests expanding 1D-Mesh reduce-scatter +TEST(Mesh2dReduceScatter, InconsecutiveProcessors) { + auto graph = absl::make_unique("test_graph", 2); + auto sub = absl::make_unique( + "test_subroutine", graph.get()); + auto sub_ptr = sub.get(); + graph->SetEntrySubroutine(std::move(sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); + instr_1->SetOps(4); + + ASSERT_OK_AND_ASSIGN(auto reducescatter, + paragraph::Instruction::Create( + paragraph::Opcode::kReduceScatter, "reduce-scatter", sub_ptr)); + reducescatter->SetBytesOut(48); + paragraph::CommunicationGroup reducescatter_group = {0, 2, 4}; + reducescatter->AppendCommunicationGroup(reducescatter_group); + + auto reduction_sub = absl::make_unique( + "reduction_subroutine", graph.get()); + auto reduction_ptr = reduction_sub.get(); + ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "op1", reduction_ptr)); + op1->SetBytesOut(48); + ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "op2", reduction_ptr)); + op2->SetBytesOut(48); + ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); + sum_op->SetOps(96); + sum_op->AddOperand(op1); + sum_op->AddOperand(op2); + reducescatter->AppendInnerSubroutine(std::move(reduction_sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); + instr_3->SetOps(4); + + nlohmann::json config = R"( + { + "reduce-scatter": { + "algorithm": "mesh-2d", + "dimension_widths": [2, 3] + } + } + )"_json; + + ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( + paragraph::TranslatorType::kCollective, config)); + EXPECT_OK(translators["reduce-scatter"]->Translate(reducescatter)); + + paragraph::InstructionProto reducescatter_proto = + Mesh2dReduceScatter_inconsecutive_proc_test_proto(); EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( reducescatter->ToProto().value(), reducescatter_proto)); } diff --git a/paragraph/translation/reducescatter/torus_2d_reducescatter_translator.cc b/paragraph/translation/reducescatter/torus_2d_reducescatter_translator.cc index 65a6b22..37a35ff 100644 --- a/paragraph/translation/reducescatter/torus_2d_reducescatter_translator.cc +++ b/paragraph/translation/reducescatter/torus_2d_reducescatter_translator.cc @@ -44,6 +44,12 @@ Torus2dReduceScatterTranslator::Torus2dReduceScatterTranslator( if (config.find("concentration") != config.end()) { concentration_ = config["concentration"].get(); } + // conentrated ports + integrated_local_exchange_ = false; + if (config.find("integrated_local_exchange") != config.end()) { + integrated_local_exchange_ = + config["integrated_local_exchange"].get(); + } // Create json config for internal 1D Torus reduce-scatter nlohmann::json implicit_config = R"( @@ -77,76 +83,90 @@ shim::StatusOr> absl::InvalidArgumentError) << "Processor index points to the wrong Processor ID."; Instruction* previous_instruction = nullptr; - std::vector processor_coordinates; - std::unordered_set whole_world(comm_group.begin(), comm_group.end()); - // Check if we have non-trivial concentration first - if (concentration_ > 1) { - processor_coordinates = ConsecutiveProcessorIdToGridCoordinates( - processor_id, dimension_sizes_, concentration_); - CommunicationGroup comm_group_conc; - for (uint64_t i = 0; i < concentration_; i++) { - processor_coordinates.at(0) = i; - uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId( - processor_coordinates, dimension_sizes_, concentration_); - if (whole_world.find(new_processor_id) != whole_world.end()) { - comm_group_conc.push_back(new_processor_id); + CommunicationGroup local_comm_group = CommunicationGroupLocalProjection( + processor_id, comm_group, dimension_sizes_, concentration_); + std::vector stage_comm_sizes; + // We prepare communication sizes for each stage and each dimension as + // dimension and/or communication groups could be uneven + for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { + if ((comm_size == 0) || (comm_group.empty())) { + stage_comm_sizes.push_back(0); + } else { + stage_comm_sizes.push_back( + comm_size / comm_group.size() / dimension_sizes_.size()); + if (integrated_local_exchange_) { + // Additionally accomodating split of traffic from concentrator among + // dimensions + stage_comm_sizes.at(dim) /= dimension_sizes_.size(); } } - if (comm_group_conc.size() > 1) { - ASSIGN_OR_RETURN(auto reducescatter_conc, Instruction::Create( - Opcode::kReduceScatter, - absl::StrCat(name_prefix, - "_dim-conc"), - reducescatter_sub_ptr)); - reducescatter_conc->AppendCommunicationGroup(comm_group_conc); - reducescatter_conc->SetBytesOut(comm_size * concentration_ / - comm_group.size()); - ASSIGN_OR_RETURN(auto reduction_subroutine_conc, - reduction_subroutine->Clone("", /*reset_ids*/ false)); - reduction_subroutine_conc->ScalePerformance(1.0 * concentration_ - / comm_group.size()); - reducescatter_conc->AppendInnerSubroutine(std::move( - reduction_subroutine_conc)); - RETURN_IF_ERROR(reducescatter_translator_->Translate(reducescatter_conc)); - previous_instruction = reducescatter_conc; - } } - // Now do the same for every dimension of the torus - for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { - processor_coordinates = ConsecutiveProcessorIdToGridCoordinates( - processor_id, dimension_sizes_, concentration_); - CommunicationGroup comm_group_torus; - uint64_t dim_width = dimension_sizes_.at(dim); - for (uint64_t i = 0; i < dim_width; i++) { - processor_coordinates.at(dim + 1) = i; - uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId( - processor_coordinates, dimension_sizes_, concentration_); - if (whole_world.find(new_processor_id) != whole_world.end()) { - comm_group_torus.push_back(new_processor_id); + // We have as many stages as dimensions in the Torus + for (size_t stage = 0; stage < dimension_sizes_.size(); stage++) { + // We run AllGather in parallel for every dimension of Torus + std::vector parallel_reducescatter; + for (size_t dim = 0; dim < dimension_sizes_.size(); dim++) { + auto new_comm_group = CommunicationGroupProjectionOnGrid( + processor_id, comm_group, dim, integrated_local_exchange_, + dimension_sizes_, concentration_); + // Every new stage we should increase communication size + // On the first stage we only exchange data laying in the 1st dimension + // On the second stage we exchange data from both 1st and 2nd dimensions + stage_comm_sizes.at(dim) *= new_comm_group.size(); + // If we don't have any communication in original comm_group within the + // current dimension, just leave it + if (new_comm_group.size() > 1) { + ASSIGN_OR_RETURN(auto reducescatter_stage, Instruction::Create( + Opcode::kReduceScatter, + absl::StrCat(name_prefix, "_stage-", stage, "_dim-", dim), + reducescatter_sub_ptr)); + reducescatter_stage->AppendCommunicationGroup(new_comm_group); + reducescatter_stage->SetBytesOut(stage_comm_sizes.at(dim)); + if (previous_instruction != nullptr) { + reducescatter_stage->AddOperand(previous_instruction); + } + ASSIGN_OR_RETURN(auto reduction_subroutine_stage, + reduction_subroutine->Clone("", /*reset_ids*/ false)); + if ((comm_size != 0) && (stage_comm_sizes.at(dim) != 0)) { + reduction_subroutine_stage->ScalePerformance( + 1.0 * stage_comm_sizes.at(dim) / comm_size); + } + reducescatter_stage->AppendInnerSubroutine(std::move( + reduction_subroutine_stage)); + RETURN_IF_ERROR(reducescatter_translator_->Translate( + reducescatter_stage)); + parallel_reducescatter.push_back(reducescatter_stage); } } - // If we don't have any communication in original comm_group within the - // current dimension, just leave it - if (comm_group_torus.size() > 1) { - ASSIGN_OR_RETURN(auto reducescatter_torus, Instruction::Create( + ASSIGN_OR_RETURN(auto reducescatter_root, Instruction::Create( + Opcode::kNull, + absl::StrCat(name_prefix, "_stage-", stage, "_root"), + reducescatter_sub_ptr)); + previous_instruction = reducescatter_root; + for (auto& instr : parallel_reducescatter) { + reducescatter_root->AddOperand(instr); + } + } + // Check if we have non-trivial concentration and need to perform + // explicit local exchange step + if ((concentration_ > 1) && !integrated_local_exchange_) { + if (local_comm_group.size() > 1) { + ASSIGN_OR_RETURN(auto reducescatter_conc, Instruction::Create( Opcode::kReduceScatter, - absl::StrCat(name_prefix, "_dim-", dim), + absl::StrCat(name_prefix, "_conc"), reducescatter_sub_ptr)); - reducescatter_torus->AppendCommunicationGroup(comm_group_torus); - reducescatter_torus->SetBytesOut(comm_size * dim_width / - comm_group.size()); - ASSIGN_OR_RETURN(auto reduction_subroutine_torus, - reduction_subroutine->Clone("", /*reset_ids*/ false)); - reduction_subroutine_torus->ScalePerformance(1.0 * dim_width - / comm_group.size()); - reducescatter_torus->AppendInnerSubroutine(std::move( - reduction_subroutine_torus)); + reducescatter_conc->AppendCommunicationGroup(local_comm_group); + reducescatter_conc->SetBytesOut(comm_size); if (previous_instruction != nullptr) { - reducescatter_torus->AddOperand(previous_instruction); + reducescatter_conc->AddOperand(previous_instruction); } + ASSIGN_OR_RETURN(auto reduction_subroutine_conc, + reduction_subroutine->Clone("", /*reset_ids*/ false)); + reducescatter_conc->AppendInnerSubroutine(std::move( + reduction_subroutine_conc)); RETURN_IF_ERROR(reducescatter_translator_->Translate( - reducescatter_torus)); - previous_instruction = reducescatter_torus; + reducescatter_conc)); + previous_instruction = reducescatter_conc; } } // Set root instruction for reducescatter subroutine diff --git a/paragraph/translation/reducescatter/torus_2d_reducescatter_translator.h b/paragraph/translation/reducescatter/torus_2d_reducescatter_translator.h index 5ddefa7..4962d86 100644 --- a/paragraph/translation/reducescatter/torus_2d_reducescatter_translator.h +++ b/paragraph/translation/reducescatter/torus_2d_reducescatter_translator.h @@ -56,6 +56,8 @@ class Torus2dReduceScatterTranslator : public ReduceScatterTranslator { std::vector dimension_sizes_; // Number of processors per torus node uint64_t concentration_; + // concentrators + bool integrated_local_exchange_; }; } // namespace paragraph diff --git a/paragraph/translation/reducescatter/torus_2d_reducescatter_translator_test.cc b/paragraph/translation/reducescatter/torus_2d_reducescatter_translator_test.cc index 118e63e..649538c 100644 --- a/paragraph/translation/reducescatter/torus_2d_reducescatter_translator_test.cc +++ b/paragraph/translation/reducescatter/torus_2d_reducescatter_translator_test.cc @@ -24,61 +24,9 @@ #include "paragraph/shim/test_macros.h" #include "paragraph/translation/translation_map.h" -// Tests expanding 2D-Torus reduce-scatter -TEST(Torus2dReduceScatter, NoBarrier) { - auto graph = absl::make_unique("test_graph", 1); - auto sub = absl::make_unique( - "test_subroutine", graph.get()); - auto sub_ptr = sub.get(); - graph->SetEntrySubroutine(std::move(sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); - instr_1->SetOps(4); - - ASSERT_OK_AND_ASSIGN(auto reducescatter, - paragraph::Instruction::Create( - paragraph::Opcode::kReduceScatter, "reduce-scatter", sub_ptr)); - reducescatter->SetBytesOut(80); - paragraph::CommunicationGroup reducescatter_group = {0, 1, 2, 3, 4, 5, 6, 7}; - reducescatter->AppendCommunicationGroup(reducescatter_group); - - auto reduction_sub = absl::make_unique( - "reduction_subroutine", graph.get()); - auto reduction_ptr = reduction_sub.get(); - ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "op1", reduction_ptr)); - op1->SetBytesOut(80); - ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "op2", reduction_ptr)); - op2->SetBytesOut(80); - ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); - sum_op->SetOps(160); - sum_op->AddOperand(op1); - sum_op->AddOperand(op2); - reducescatter->AppendInnerSubroutine(std::move(reduction_sub)); - - ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( - paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); - instr_3->SetOps(4); - - nlohmann::json config = R"( - { - "reduce-scatter": { - "algorithm": "torus-2d", - "concentration": 2, - "dimension_widths": [2, 2] - } - } - )"_json; - - ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( - paragraph::TranslatorType::kCollective, config)); - EXPECT_OK(translators["reduce-scatter"]->Translate(reducescatter)); - - paragraph::InstructionProto reducescatter_proto; - std::string reducescatter_str = +paragraph::InstructionProto no_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "reduce-scatter" opcode: "reduce-scatter" @@ -96,50 +44,54 @@ communication_groups { } inner_subroutines { name: "reduce-scatter_torus-2d" - subroutine_root_id: 35 + subroutine_root_id: 144 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc" + name: "reduce-scatter_stage-0_dim-0" opcode: "reduce-scatter" instruction_id: 7 - bytes_out: 20 + bytes_out: 40 communication_groups { group_ids: 0 group_ids: 1 + group_ids: 2 + group_ids: 3 } inner_subroutines { - name: "reduce-scatter_dim-conc_bidir-ring" - subroutine_root_id: 20 + name: "reduce-scatter_stage-0_dim-0_bidir-ring" + subroutine_root_id: 40 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc_bidir-ring_cw" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw" opcode: "reduce-scatter" instruction_id: 8 - bytes_out: 10 + bytes_out: 20 communication_groups { group_ids: 0 group_ids: 1 + group_ids: 2 + group_ids: 3 } inner_subroutines { - name: "reduce-scatter_dim-conc_bidir-ring_cw_unidir-ring" - subroutine_root_id: 10 + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 20 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 9 bytes_in: 5 bytes_out: 5 communication_groups { group_ids: 0 - group_ids: 0 + group_ids: 2 } } instructions { - name: "reduce-scatter_dim-conc_bidir-ring_cw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_reduction_1" opcode: "call" instruction_id: 10 operand_ids: 9 @@ -170,375 +122,1252 @@ inner_subroutines { } } } - } - } - instructions { - name: "reduce-scatter_dim-conc_bidir-ring_ccw" - opcode: "reduce-scatter" - instruction_id: 14 - bytes_out: 10 - communication_groups { - group_ids: 1 - group_ids: 0 - } - inner_subroutines { - name: "reduce-scatter_dim-conc_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 16 - execution_probability: 1 - execution_count: 1 instructions { - name: "reduce-scatter_dim-conc_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_2" opcode: "sendrecv" - instruction_id: 15 + instruction_id: 14 bytes_in: 5 bytes_out: 5 communication_groups { group_ids: 0 - group_ids: 0 + group_ids: 2 } + operand_ids: 10 } instructions { - name: "reduce-scatter_dim-conc_bidir-ring_ccw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_reduction_2" opcode: "call" - instruction_id: 16 - operand_ids: 15 + instruction_id: 15 + operand_ids: 14 inner_subroutines { - name: "reduction_subroutine_phase_1" - subroutine_root_id: 19 + name: "reduction_subroutine_phase_2" + subroutine_root_id: 18 execution_probability: 1 execution_count: 1 instructions { - name: "op1_phase_1" + name: "op1_phase_2" opcode: "delay" - instruction_id: 17 + instruction_id: 16 bytes_out: 5 } instructions { - name: "op2_phase_1" + name: "op2_phase_2" opcode: "delay" - instruction_id: 18 + instruction_id: 17 bytes_out: 5 } instructions { - name: "sum_phase_1" + name: "sum_phase_2" opcode: "delay" - instruction_id: 19 + instruction_id: 18 ops: 10 + operand_ids: 16 operand_ids: 17 - operand_ids: 18 } } } - } - } - instructions { - name: "reduce-scatter_dim-conc_bidir-ring_root_1" - opcode: "null" - instruction_id: 20 - operand_ids: 8 - operand_ids: 14 - } - } - } - instructions { - name: "reduce-scatter_dim-0" - opcode: "reduce-scatter" - instruction_id: 21 - bytes_out: 20 - communication_groups { - group_ids: 1 - group_ids: 3 - } - operand_ids: 7 - inner_subroutines { - name: "reduce-scatter_dim-0_bidir-ring" - subroutine_root_id: 34 - execution_probability: 1 - execution_count: 1 - instructions { - name: "reduce-scatter_dim-0_bidir-ring_cw" - opcode: "reduce-scatter" - instruction_id: 22 - bytes_out: 10 - communication_groups { - group_ids: 1 - group_ids: 3 - } - inner_subroutines { - name: "reduce-scatter_dim-0_bidir-ring_cw_unidir-ring" - subroutine_root_id: 24 - execution_probability: 1 - execution_count: 1 instructions { - name: "reduce-scatter_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_3" opcode: "sendrecv" - instruction_id: 23 + instruction_id: 19 bytes_in: 5 bytes_out: 5 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 2 } + operand_ids: 15 } instructions { - name: "reduce-scatter_dim-0_bidir-ring_cw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_reduction_3" opcode: "call" - instruction_id: 24 - operand_ids: 23 + instruction_id: 20 + operand_ids: 19 inner_subroutines { - name: "reduction_subroutine_phase_1" - subroutine_root_id: 27 + name: "reduction_subroutine_phase_3" + subroutine_root_id: 23 execution_probability: 1 execution_count: 1 instructions { - name: "op1_phase_1" + name: "op1_phase_3" opcode: "delay" - instruction_id: 25 + instruction_id: 21 bytes_out: 5 } instructions { - name: "op2_phase_1" + name: "op2_phase_3" opcode: "delay" - instruction_id: 26 + instruction_id: 22 bytes_out: 5 } instructions { - name: "sum_phase_1" + name: "sum_phase_3" opcode: "delay" - instruction_id: 27 + instruction_id: 23 ops: 10 - operand_ids: 25 - operand_ids: 26 + operand_ids: 21 + operand_ids: 22 } } } } } instructions { - name: "reduce-scatter_dim-0_bidir-ring_ccw" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw" opcode: "reduce-scatter" - instruction_id: 28 - bytes_out: 10 + instruction_id: 24 + bytes_out: 20 communication_groups { group_ids: 3 + group_ids: 2 group_ids: 1 + group_ids: 0 } inner_subroutines { - name: "reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 30 + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 36 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 29 + instruction_id: 25 bytes_in: 5 bytes_out: 5 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 2 + group_ids: 0 } } instructions { - name: "reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" opcode: "call" - instruction_id: 30 - operand_ids: 29 + instruction_id: 26 + operand_ids: 25 inner_subroutines { name: "reduction_subroutine_phase_1" - subroutine_root_id: 33 + subroutine_root_id: 29 execution_probability: 1 execution_count: 1 instructions { name: "op1_phase_1" opcode: "delay" - instruction_id: 31 + instruction_id: 27 bytes_out: 5 } instructions { name: "op2_phase_1" opcode: "delay" - instruction_id: 32 + instruction_id: 28 bytes_out: 5 } instructions { name: "sum_phase_1" opcode: "delay" - instruction_id: 33 + instruction_id: 29 ops: 10 - operand_ids: 31 - operand_ids: 32 + operand_ids: 27 + operand_ids: 28 } } } - } - } - instructions { - name: "reduce-scatter_dim-0_bidir-ring_root_1" - opcode: "null" - instruction_id: 34 - operand_ids: 22 - operand_ids: 28 - } - } - } - instructions { - name: "reduce-scatter_dim-1" - opcode: "reduce-scatter" - instruction_id: 35 - bytes_out: 20 - communication_groups { - group_ids: 1 - group_ids: 5 - } - operand_ids: 21 - inner_subroutines { - name: "reduce-scatter_dim-1_bidir-ring" - subroutine_root_id: 48 - execution_probability: 1 - execution_count: 1 - instructions { - name: "reduce-scatter_dim-1_bidir-ring_cw" - opcode: "reduce-scatter" - instruction_id: 36 - bytes_out: 10 - communication_groups { - group_ids: 1 - group_ids: 5 - } - inner_subroutines { - name: "reduce-scatter_dim-1_bidir-ring_cw_unidir-ring" - subroutine_root_id: 38 - execution_probability: 1 - execution_count: 1 instructions { - name: "reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_2" opcode: "sendrecv" - instruction_id: 37 + instruction_id: 30 bytes_in: 5 bytes_out: 5 communication_groups { - group_ids: 5 - group_ids: 5 + group_ids: 2 + group_ids: 0 } + operand_ids: 26 } instructions { - name: "reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_reduction_2" opcode: "call" - instruction_id: 38 - operand_ids: 37 + instruction_id: 31 + operand_ids: 30 inner_subroutines { - name: "reduction_subroutine_phase_1" - subroutine_root_id: 41 + name: "reduction_subroutine_phase_2" + subroutine_root_id: 34 execution_probability: 1 execution_count: 1 instructions { - name: "op1_phase_1" + name: "op1_phase_2" opcode: "delay" - instruction_id: 39 + instruction_id: 32 bytes_out: 5 } instructions { - name: "op2_phase_1" + name: "op2_phase_2" opcode: "delay" - instruction_id: 40 + instruction_id: 33 bytes_out: 5 } instructions { - name: "sum_phase_1" + name: "sum_phase_2" opcode: "delay" - instruction_id: 41 + instruction_id: 34 ops: 10 - operand_ids: 39 - operand_ids: 40 + operand_ids: 32 + operand_ids: 33 } } } - } - } - instructions { - name: "reduce-scatter_dim-1_bidir-ring_ccw" - opcode: "reduce-scatter" - instruction_id: 42 - bytes_out: 10 - communication_groups { - group_ids: 5 - group_ids: 1 - } - inner_subroutines { - name: "reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 44 - execution_probability: 1 - execution_count: 1 instructions { - name: "reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_3" opcode: "sendrecv" - instruction_id: 43 + instruction_id: 35 bytes_in: 5 bytes_out: 5 communication_groups { - group_ids: 5 - group_ids: 5 + group_ids: 2 + group_ids: 0 } + operand_ids: 31 } instructions { - name: "reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_reduction_3" opcode: "call" - instruction_id: 44 - operand_ids: 43 + instruction_id: 36 + operand_ids: 35 inner_subroutines { - name: "reduction_subroutine_phase_1" - subroutine_root_id: 47 + name: "reduction_subroutine_phase_3" + subroutine_root_id: 39 execution_probability: 1 execution_count: 1 instructions { - name: "op1_phase_1" + name: "op1_phase_3" opcode: "delay" - instruction_id: 45 + instruction_id: 37 bytes_out: 5 } instructions { - name: "op2_phase_1" + name: "op2_phase_3" opcode: "delay" - instruction_id: 46 + instruction_id: 38 bytes_out: 5 } instructions { - name: "sum_phase_1" + name: "sum_phase_3" opcode: "delay" - instruction_id: 47 + instruction_id: 39 ops: 10 - operand_ids: 45 - operand_ids: 46 + operand_ids: 37 + operand_ids: 38 } } } } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_root_1" + name: "reduce-scatter_stage-0_dim-0_bidir-ring_root_1" opcode: "null" - instruction_id: 48 - operand_ids: 36 - operand_ids: 42 + instruction_id: 40 + operand_ids: 8 + operand_ids: 24 } } } -} - )proto"; - google::protobuf::TextFormat::ParseFromString(reducescatter_str, - &reducescatter_proto); - EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( - reducescatter->ToProto().value(), reducescatter_proto)); -} - -// Tests expanding 1D-Torus reduce-scatter with barrier -TEST(Torus2dReduceScatter, WithBarrier) { - auto graph = absl::make_unique("test_graph", 2); - auto sub = absl::make_unique( - "test_subroutine", graph.get()); - auto sub_ptr = sub.get(); - sub_ptr->SetId(3); + instructions { + name: "reduce-scatter_stage-0_dim-1" + opcode: "reduce-scatter" + instruction_id: 41 + bytes_out: 40 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 4 + group_ids: 5 + } + inner_subroutines { + name: "reduce-scatter_stage-0_dim-1_bidir-ring" + subroutine_root_id: 74 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw" + opcode: "reduce-scatter" + instruction_id: 42 + bytes_out: 20 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 4 + group_ids: 5 + } + inner_subroutines { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 54 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 43 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 0 + group_ids: 4 + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 44 + operand_ids: 43 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 47 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 45 + bytes_out: 5 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 46 + bytes_out: 5 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 47 + ops: 10 + operand_ids: 45 + operand_ids: 46 + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 48 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 0 + group_ids: 4 + } + operand_ids: 44 + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_reduction_2" + opcode: "call" + instruction_id: 49 + operand_ids: 48 + inner_subroutines { + name: "reduction_subroutine_phase_2" + subroutine_root_id: 52 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_2" + opcode: "delay" + instruction_id: 50 + bytes_out: 5 + } + instructions { + name: "op2_phase_2" + opcode: "delay" + instruction_id: 51 + bytes_out: 5 + } + instructions { + name: "sum_phase_2" + opcode: "delay" + instruction_id: 52 + ops: 10 + operand_ids: 50 + operand_ids: 51 + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 53 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 0 + group_ids: 4 + } + operand_ids: 49 + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_reduction_3" + opcode: "call" + instruction_id: 54 + operand_ids: 53 + inner_subroutines { + name: "reduction_subroutine_phase_3" + subroutine_root_id: 57 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_3" + opcode: "delay" + instruction_id: 55 + bytes_out: 5 + } + instructions { + name: "op2_phase_3" + opcode: "delay" + instruction_id: 56 + bytes_out: 5 + } + instructions { + name: "sum_phase_3" + opcode: "delay" + instruction_id: 57 + ops: 10 + operand_ids: 55 + operand_ids: 56 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 58 + bytes_out: 20 + communication_groups { + group_ids: 5 + group_ids: 4 + group_ids: 1 + group_ids: 0 + } + inner_subroutines { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 70 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 59 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 4 + group_ids: 0 + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 60 + operand_ids: 59 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 63 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 61 + bytes_out: 5 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 62 + bytes_out: 5 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 63 + ops: 10 + operand_ids: 61 + operand_ids: 62 + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 64 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 4 + group_ids: 0 + } + operand_ids: 60 + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_reduction_2" + opcode: "call" + instruction_id: 65 + operand_ids: 64 + inner_subroutines { + name: "reduction_subroutine_phase_2" + subroutine_root_id: 68 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_2" + opcode: "delay" + instruction_id: 66 + bytes_out: 5 + } + instructions { + name: "op2_phase_2" + opcode: "delay" + instruction_id: 67 + bytes_out: 5 + } + instructions { + name: "sum_phase_2" + opcode: "delay" + instruction_id: 68 + ops: 10 + operand_ids: 66 + operand_ids: 67 + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 69 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 4 + group_ids: 0 + } + operand_ids: 65 + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_reduction_3" + opcode: "call" + instruction_id: 70 + operand_ids: 69 + inner_subroutines { + name: "reduction_subroutine_phase_3" + subroutine_root_id: 73 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_3" + opcode: "delay" + instruction_id: 71 + bytes_out: 5 + } + instructions { + name: "op2_phase_3" + opcode: "delay" + instruction_id: 72 + bytes_out: 5 + } + instructions { + name: "sum_phase_3" + opcode: "delay" + instruction_id: 73 + ops: 10 + operand_ids: 71 + operand_ids: 72 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_root_1" + opcode: "null" + instruction_id: 74 + operand_ids: 42 + operand_ids: 58 + } + } + } + instructions { + name: "reduce-scatter_stage-0_root" + opcode: "null" + instruction_id: 75 + operand_ids: 7 + operand_ids: 41 + } + instructions { + name: "reduce-scatter_stage-1_dim-0" + opcode: "reduce-scatter" + instruction_id: 76 + bytes_out: 80 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 2 + group_ids: 3 + } + operand_ids: 75 + inner_subroutines { + name: "reduce-scatter_stage-1_dim-0_bidir-ring" + subroutine_root_id: 109 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw" + opcode: "reduce-scatter" + instruction_id: 77 + bytes_out: 40 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 2 + group_ids: 3 + } + inner_subroutines { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 89 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 78 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 2 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 79 + operand_ids: 78 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 82 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 80 + bytes_out: 10 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 81 + bytes_out: 10 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 82 + ops: 20 + operand_ids: 80 + operand_ids: 81 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 83 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 79 + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_reduction_2" + opcode: "call" + instruction_id: 84 + operand_ids: 83 + inner_subroutines { + name: "reduction_subroutine_phase_2" + subroutine_root_id: 87 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_2" + opcode: "delay" + instruction_id: 85 + bytes_out: 10 + } + instructions { + name: "op2_phase_2" + opcode: "delay" + instruction_id: 86 + bytes_out: 10 + } + instructions { + name: "sum_phase_2" + opcode: "delay" + instruction_id: 87 + ops: 20 + operand_ids: 85 + operand_ids: 86 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 88 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 84 + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_reduction_3" + opcode: "call" + instruction_id: 89 + operand_ids: 88 + inner_subroutines { + name: "reduction_subroutine_phase_3" + subroutine_root_id: 92 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_3" + opcode: "delay" + instruction_id: 90 + bytes_out: 10 + } + instructions { + name: "op2_phase_3" + opcode: "delay" + instruction_id: 91 + bytes_out: 10 + } + instructions { + name: "sum_phase_3" + opcode: "delay" + instruction_id: 92 + ops: 20 + operand_ids: 90 + operand_ids: 91 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 93 + bytes_out: 40 + communication_groups { + group_ids: 3 + group_ids: 2 + group_ids: 1 + group_ids: 0 + } + inner_subroutines { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 105 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 94 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 2 + group_ids: 0 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 95 + operand_ids: 94 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 98 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 96 + bytes_out: 10 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 97 + bytes_out: 10 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 98 + ops: 20 + operand_ids: 96 + operand_ids: 97 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 99 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 2 + group_ids: 0 + } + operand_ids: 95 + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_reduction_2" + opcode: "call" + instruction_id: 100 + operand_ids: 99 + inner_subroutines { + name: "reduction_subroutine_phase_2" + subroutine_root_id: 103 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_2" + opcode: "delay" + instruction_id: 101 + bytes_out: 10 + } + instructions { + name: "op2_phase_2" + opcode: "delay" + instruction_id: 102 + bytes_out: 10 + } + instructions { + name: "sum_phase_2" + opcode: "delay" + instruction_id: 103 + ops: 20 + operand_ids: 101 + operand_ids: 102 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 104 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 2 + group_ids: 0 + } + operand_ids: 100 + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_reduction_3" + opcode: "call" + instruction_id: 105 + operand_ids: 104 + inner_subroutines { + name: "reduction_subroutine_phase_3" + subroutine_root_id: 108 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_3" + opcode: "delay" + instruction_id: 106 + bytes_out: 10 + } + instructions { + name: "op2_phase_3" + opcode: "delay" + instruction_id: 107 + bytes_out: 10 + } + instructions { + name: "sum_phase_3" + opcode: "delay" + instruction_id: 108 + ops: 20 + operand_ids: 106 + operand_ids: 107 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-0_bidir-ring_root_1" + opcode: "null" + instruction_id: 109 + operand_ids: 77 + operand_ids: 93 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1" + opcode: "reduce-scatter" + instruction_id: 110 + bytes_out: 80 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 4 + group_ids: 5 + } + operand_ids: 75 + inner_subroutines { + name: "reduce-scatter_stage-1_dim-1_bidir-ring" + subroutine_root_id: 143 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw" + opcode: "reduce-scatter" + instruction_id: 111 + bytes_out: 40 + communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 4 + group_ids: 5 + } + inner_subroutines { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 123 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 112 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 4 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 113 + operand_ids: 112 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 116 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 114 + bytes_out: 10 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 115 + bytes_out: 10 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 116 + ops: 20 + operand_ids: 114 + operand_ids: 115 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 117 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 4 + } + operand_ids: 113 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_reduction_2" + opcode: "call" + instruction_id: 118 + operand_ids: 117 + inner_subroutines { + name: "reduction_subroutine_phase_2" + subroutine_root_id: 121 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_2" + opcode: "delay" + instruction_id: 119 + bytes_out: 10 + } + instructions { + name: "op2_phase_2" + opcode: "delay" + instruction_id: 120 + bytes_out: 10 + } + instructions { + name: "sum_phase_2" + opcode: "delay" + instruction_id: 121 + ops: 20 + operand_ids: 119 + operand_ids: 120 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 122 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 4 + } + operand_ids: 118 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_reduction_3" + opcode: "call" + instruction_id: 123 + operand_ids: 122 + inner_subroutines { + name: "reduction_subroutine_phase_3" + subroutine_root_id: 126 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_3" + opcode: "delay" + instruction_id: 124 + bytes_out: 10 + } + instructions { + name: "op2_phase_3" + opcode: "delay" + instruction_id: 125 + bytes_out: 10 + } + instructions { + name: "sum_phase_3" + opcode: "delay" + instruction_id: 126 + ops: 20 + operand_ids: 124 + operand_ids: 125 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 127 + bytes_out: 40 + communication_groups { + group_ids: 5 + group_ids: 4 + group_ids: 1 + group_ids: 0 + } + inner_subroutines { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 139 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 128 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 4 + group_ids: 0 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 129 + operand_ids: 128 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 132 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 130 + bytes_out: 10 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 131 + bytes_out: 10 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 132 + ops: 20 + operand_ids: 130 + operand_ids: 131 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 133 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 4 + group_ids: 0 + } + operand_ids: 129 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_reduction_2" + opcode: "call" + instruction_id: 134 + operand_ids: 133 + inner_subroutines { + name: "reduction_subroutine_phase_2" + subroutine_root_id: 137 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_2" + opcode: "delay" + instruction_id: 135 + bytes_out: 10 + } + instructions { + name: "op2_phase_2" + opcode: "delay" + instruction_id: 136 + bytes_out: 10 + } + instructions { + name: "sum_phase_2" + opcode: "delay" + instruction_id: 137 + ops: 20 + operand_ids: 135 + operand_ids: 136 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_3" + opcode: "sendrecv" + instruction_id: 138 + bytes_in: 10 + bytes_out: 10 + communication_groups { + group_ids: 4 + group_ids: 0 + } + operand_ids: 134 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_reduction_3" + opcode: "call" + instruction_id: 139 + operand_ids: 138 + inner_subroutines { + name: "reduction_subroutine_phase_3" + subroutine_root_id: 142 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_3" + opcode: "delay" + instruction_id: 140 + bytes_out: 10 + } + instructions { + name: "op2_phase_3" + opcode: "delay" + instruction_id: 141 + bytes_out: 10 + } + instructions { + name: "sum_phase_3" + opcode: "delay" + instruction_id: 142 + ops: 20 + operand_ids: 140 + operand_ids: 141 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_root_1" + opcode: "null" + instruction_id: 143 + operand_ids: 111 + operand_ids: 127 + } + } + } + instructions { + name: "reduce-scatter_stage-1_root" + opcode: "null" + instruction_id: 144 + operand_ids: 76 + operand_ids: 110 + } +} + )proto"; + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT + +// Tests expanding 2D-Torus reduce-scatter +TEST(Torus2dReduceScatter, NoBarrier) { + auto graph = absl::make_unique("test_graph", 1); + auto sub = absl::make_unique( + "test_subroutine", graph.get()); + auto sub_ptr = sub.get(); graph->SetEntrySubroutine(std::move(sub)); ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( @@ -578,592 +1407,967 @@ TEST(Torus2dReduceScatter, WithBarrier) { "algorithm": "torus-2d", "concentration": 2, "dimension_widths": [2, 2], - "barrier": { - "algorithm": "centralized" + "integrated_local_exchange": true + } + } + )"_json; + + ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( + paragraph::TranslatorType::kCollective, config)); + EXPECT_OK(translators["reduce-scatter"]->Translate(reducescatter)); + + paragraph::InstructionProto reducescatter_proto = no_barrier_test_proto(); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + reducescatter->ToProto().value(), reducescatter_proto)); +} + +paragraph::InstructionProto with_barrier_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = + R"proto( +name: "reduce-scatter" +opcode: "reduce-scatter" +instruction_id: 2 +bytes_out: 80 +communication_groups { + group_ids: 0 + group_ids: 1 + group_ids: 2 + group_ids: 3 + group_ids: 4 + group_ids: 5 + group_ids: 6 + group_ids: 7 +} +inner_subroutines { + name: "reduce-scatter_torus-2d" + subroutine_root_id: 79 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-0" + opcode: "reduce-scatter" + instruction_id: 7 + bytes_out: 20 + communication_groups { + group_ids: 0 + group_ids: 2 + } + inner_subroutines { + name: "reduce-scatter_stage-0_dim-0_bidir-ring" + subroutine_root_id: 23 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_barrier" + opcode: "barrier" + instruction_id: 8 + communication_groups { + group_ids: 0 + group_ids: 2 + } + inner_subroutines { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_barrier_centralized" + subroutine_root_id: 10 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_barrier_centralized_send_to_0" + opcode: "send" + instruction_id: 9 + communication_groups { + group_ids: 0 + } + } + instructions { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_barrier_centralized_recv_from_0" + opcode: "recv" + instruction_id: 10 + communication_groups { + group_ids: 0 + } + operand_ids: 9 + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw" + opcode: "reduce-scatter" + instruction_id: 11 + bytes_out: 10 + communication_groups { + group_ids: 0 + group_ids: 2 + } + operand_ids: 8 + inner_subroutines { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 13 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 12 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 13 + operand_ids: 12 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 16 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 14 + bytes_out: 5 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 15 + bytes_out: 5 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 16 + ops: 10 + operand_ids: 14 + operand_ids: 15 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 17 + bytes_out: 10 + communication_groups { + group_ids: 2 + group_ids: 0 + } + operand_ids: 8 + inner_subroutines { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 19 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 18 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 0 + group_ids: 0 + } + } + instructions { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 19 + operand_ids: 18 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 22 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 20 + bytes_out: 5 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 21 + bytes_out: 5 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 22 + ops: 10 + operand_ids: 20 + operand_ids: 21 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-0_bidir-ring_root_2" + opcode: "null" + instruction_id: 23 + operand_ids: 11 + operand_ids: 17 + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1" + opcode: "reduce-scatter" + instruction_id: 24 + bytes_out: 20 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "reduce-scatter_stage-0_dim-1_bidir-ring" + subroutine_root_id: 41 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_barrier" + opcode: "barrier" + instruction_id: 25 + communication_groups { + group_ids: 2 + group_ids: 6 + } + inner_subroutines { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_barrier_centralized" + subroutine_root_id: 28 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_barrier_centralized_coordinator_recv_from_6" + opcode: "recv" + instruction_id: 26 + communication_groups { + group_ids: 6 + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_barrier_centralized_coordinator_send_to_6" + opcode: "send" + instruction_id: 27 + communication_groups { + group_ids: 6 + } + operand_ids: 26 + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_barrier_centralized_root_2" + opcode: "null" + instruction_id: 28 + operand_ids: 27 + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw" + opcode: "reduce-scatter" + instruction_id: 29 + bytes_out: 10 + communication_groups { + group_ids: 2 + group_ids: 6 + } + operand_ids: 25 + inner_subroutines { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 31 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 30 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 6 + group_ids: 6 + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 31 + operand_ids: 30 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 34 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 32 + bytes_out: 5 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 33 + bytes_out: 5 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 34 + ops: 10 + operand_ids: 32 + operand_ids: 33 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 35 + bytes_out: 10 + communication_groups { + group_ids: 6 + group_ids: 2 + } + operand_ids: 25 + inner_subroutines { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 37 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 36 + bytes_in: 5 + bytes_out: 5 + communication_groups { + group_ids: 6 + group_ids: 6 + } + } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 37 + operand_ids: 36 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 40 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 38 + bytes_out: 5 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 39 + bytes_out: 5 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 40 + ops: 10 + operand_ids: 38 + operand_ids: 39 + } + } + } } } + instructions { + name: "reduce-scatter_stage-0_dim-1_bidir-ring_root_2" + opcode: "null" + instruction_id: 41 + operand_ids: 29 + operand_ids: 35 + } } - )"_json; - - ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( - paragraph::TranslatorType::kCollective, config)); - EXPECT_OK(translators["reduce-scatter"]->Translate(reducescatter)); - - paragraph::InstructionProto reducescatter_proto; - std::string reducescatter_str = - R"proto( -name: "reduce-scatter" -opcode: "reduce-scatter" -instruction_id: 2 -bytes_out: 80 -communication_groups { - group_ids: 0 - group_ids: 1 - group_ids: 2 - group_ids: 3 - group_ids: 4 - group_ids: 5 - group_ids: 6 - group_ids: 7 -} -inner_subroutines { - name: "reduce-scatter_torus-2d" - subroutine_root_id: 42 - execution_probability: 1 - execution_count: 1 + } + instructions { + name: "reduce-scatter_stage-0_root" + opcode: "null" + instruction_id: 42 + operand_ids: 7 + operand_ids: 24 + } instructions { - name: "reduce-scatter_dim-conc" + name: "reduce-scatter_stage-1_dim-0" opcode: "reduce-scatter" - instruction_id: 7 - bytes_out: 20 + instruction_id: 43 + bytes_out: 40 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } + operand_ids: 42 inner_subroutines { - name: "reduce-scatter_dim-conc_bidir-ring" - subroutine_root_id: 24 + name: "reduce-scatter_stage-1_dim-0_bidir-ring" + subroutine_root_id: 59 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc_bidir-ring_barrier" + name: "reduce-scatter_stage-1_dim-0_bidir-ring_barrier" opcode: "barrier" - instruction_id: 8 + instruction_id: 44 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } inner_subroutines { - name: "reduce-scatter_dim-conc_bidir-ring_barrier_centralized" - subroutine_root_id: 11 + name: "reduce-scatter_stage-1_dim-0_bidir-ring_barrier_centralized" + subroutine_root_id: 46 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc_bidir-ring_barrier_centralized_coordinator_recv_from_3" - opcode: "recv" - instruction_id: 9 + name: "reduce-scatter_stage-1_dim-0_bidir-ring_barrier_centralized_send_to_0" + opcode: "send" + instruction_id: 45 communication_groups { - group_ids: 3 + group_ids: 0 } } instructions { - name: "reduce-scatter_dim-conc_bidir-ring_barrier_centralized_coordinator_send_to_3" - opcode: "send" - instruction_id: 10 + name: "reduce-scatter_stage-1_dim-0_bidir-ring_barrier_centralized_recv_from_0" + opcode: "recv" + instruction_id: 46 communication_groups { - group_ids: 3 + group_ids: 0 } - operand_ids: 9 - } - instructions { - name: "reduce-scatter_dim-conc_bidir-ring_barrier_centralized_root_2" - opcode: "null" - instruction_id: 11 - operand_ids: 10 + operand_ids: 45 } } } instructions { - name: "reduce-scatter_dim-conc_bidir-ring_cw" + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw" opcode: "reduce-scatter" - instruction_id: 12 - bytes_out: 10 + instruction_id: 47 + bytes_out: 20 communication_groups { + group_ids: 0 group_ids: 2 - group_ids: 3 } - operand_ids: 8 + operand_ids: 44 inner_subroutines { - name: "reduce-scatter_dim-conc_bidir-ring_cw_unidir-ring" - subroutine_root_id: 14 + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring" + subroutine_root_id: 49 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 13 - bytes_in: 5 - bytes_out: 5 + instruction_id: 48 + bytes_in: 10 + bytes_out: 10 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 0 } } instructions { - name: "reduce-scatter_dim-conc_bidir-ring_cw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-1_dim-0_bidir-ring_cw_unidir-ring_reduction_1" opcode: "call" - instruction_id: 14 - operand_ids: 13 + instruction_id: 49 + operand_ids: 48 inner_subroutines { name: "reduction_subroutine_phase_1" - subroutine_root_id: 17 + subroutine_root_id: 52 execution_probability: 1 execution_count: 1 instructions { name: "op1_phase_1" opcode: "delay" - instruction_id: 15 - bytes_out: 5 + instruction_id: 50 + bytes_out: 10 } instructions { name: "op2_phase_1" opcode: "delay" - instruction_id: 16 - bytes_out: 5 + instruction_id: 51 + bytes_out: 10 } instructions { name: "sum_phase_1" opcode: "delay" - instruction_id: 17 - ops: 10 - operand_ids: 15 - operand_ids: 16 + instruction_id: 52 + ops: 20 + operand_ids: 50 + operand_ids: 51 } } } } } instructions { - name: "reduce-scatter_dim-conc_bidir-ring_ccw" + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw" opcode: "reduce-scatter" - instruction_id: 18 - bytes_out: 10 + instruction_id: 53 + bytes_out: 20 communication_groups { - group_ids: 3 group_ids: 2 + group_ids: 0 } - operand_ids: 8 + operand_ids: 44 inner_subroutines { - name: "reduce-scatter_dim-conc_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 20 + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 55 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-conc_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 19 - bytes_in: 5 - bytes_out: 5 + instruction_id: 54 + bytes_in: 10 + bytes_out: 10 communication_groups { - group_ids: 3 - group_ids: 3 + group_ids: 0 + group_ids: 0 } } instructions { - name: "reduce-scatter_dim-conc_bidir-ring_ccw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-1_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" opcode: "call" - instruction_id: 20 - operand_ids: 19 + instruction_id: 55 + operand_ids: 54 inner_subroutines { name: "reduction_subroutine_phase_1" - subroutine_root_id: 23 + subroutine_root_id: 58 execution_probability: 1 execution_count: 1 instructions { name: "op1_phase_1" opcode: "delay" - instruction_id: 21 - bytes_out: 5 + instruction_id: 56 + bytes_out: 10 } instructions { name: "op2_phase_1" opcode: "delay" - instruction_id: 22 - bytes_out: 5 + instruction_id: 57 + bytes_out: 10 } instructions { name: "sum_phase_1" opcode: "delay" - instruction_id: 23 - ops: 10 - operand_ids: 21 - operand_ids: 22 + instruction_id: 58 + ops: 20 + operand_ids: 56 + operand_ids: 57 } } } } } instructions { - name: "reduce-scatter_dim-conc_bidir-ring_root_2" + name: "reduce-scatter_stage-1_dim-0_bidir-ring_root_2" opcode: "null" - instruction_id: 24 - operand_ids: 12 - operand_ids: 18 + instruction_id: 59 + operand_ids: 47 + operand_ids: 53 } } } instructions { - name: "reduce-scatter_dim-0" + name: "reduce-scatter_stage-1_dim-1" opcode: "reduce-scatter" - instruction_id: 25 - bytes_out: 20 + instruction_id: 60 + bytes_out: 40 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } - operand_ids: 7 + operand_ids: 42 inner_subroutines { - name: "reduce-scatter_dim-0_bidir-ring" - subroutine_root_id: 41 + name: "reduce-scatter_stage-1_dim-1_bidir-ring" + subroutine_root_id: 77 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-0_bidir-ring_barrier" + name: "reduce-scatter_stage-1_dim-1_bidir-ring_barrier" opcode: "barrier" - instruction_id: 26 + instruction_id: 61 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } inner_subroutines { - name: "reduce-scatter_dim-0_bidir-ring_barrier_centralized" - subroutine_root_id: 28 + name: "reduce-scatter_stage-1_dim-1_bidir-ring_barrier_centralized" + subroutine_root_id: 64 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-0_bidir-ring_barrier_centralized_send_to_0" - opcode: "send" - instruction_id: 27 + name: "reduce-scatter_stage-1_dim-1_bidir-ring_barrier_centralized_coordinator_recv_from_6" + opcode: "recv" + instruction_id: 62 communication_groups { - group_ids: 0 + group_ids: 6 } } instructions { - name: "reduce-scatter_dim-0_bidir-ring_barrier_centralized_recv_from_0" - opcode: "recv" - instruction_id: 28 + name: "reduce-scatter_stage-1_dim-1_bidir-ring_barrier_centralized_coordinator_send_to_6" + opcode: "send" + instruction_id: 63 communication_groups { - group_ids: 0 + group_ids: 6 } - operand_ids: 27 + operand_ids: 62 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_barrier_centralized_root_2" + opcode: "null" + instruction_id: 64 + operand_ids: 63 } } } instructions { - name: "reduce-scatter_dim-0_bidir-ring_cw" + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw" opcode: "reduce-scatter" - instruction_id: 29 - bytes_out: 10 + instruction_id: 65 + bytes_out: 20 communication_groups { - group_ids: 0 group_ids: 2 + group_ids: 6 } - operand_ids: 26 + operand_ids: 61 inner_subroutines { - name: "reduce-scatter_dim-0_bidir-ring_cw_unidir-ring" - subroutine_root_id: 31 + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 67 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-0_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 30 - bytes_in: 5 - bytes_out: 5 + instruction_id: 66 + bytes_in: 10 + bytes_out: 10 communication_groups { - group_ids: 0 - group_ids: 0 + group_ids: 6 + group_ids: 6 } } instructions { - name: "reduce-scatter_dim-0_bidir-ring_cw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_reduction_1" opcode: "call" - instruction_id: 31 - operand_ids: 30 + instruction_id: 67 + operand_ids: 66 inner_subroutines { name: "reduction_subroutine_phase_1" - subroutine_root_id: 34 + subroutine_root_id: 70 execution_probability: 1 execution_count: 1 instructions { name: "op1_phase_1" opcode: "delay" - instruction_id: 32 - bytes_out: 5 + instruction_id: 68 + bytes_out: 10 } instructions { name: "op2_phase_1" opcode: "delay" - instruction_id: 33 - bytes_out: 5 + instruction_id: 69 + bytes_out: 10 } instructions { name: "sum_phase_1" opcode: "delay" - instruction_id: 34 - ops: 10 - operand_ids: 32 - operand_ids: 33 + instruction_id: 70 + ops: 20 + operand_ids: 68 + operand_ids: 69 } } } } } instructions { - name: "reduce-scatter_dim-0_bidir-ring_ccw" + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw" opcode: "reduce-scatter" - instruction_id: 35 - bytes_out: 10 + instruction_id: 71 + bytes_out: 20 communication_groups { + group_ids: 6 group_ids: 2 - group_ids: 0 } - operand_ids: 26 + operand_ids: 61 inner_subroutines { - name: "reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 37 + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 73 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 36 - bytes_in: 5 - bytes_out: 5 + instruction_id: 72 + bytes_in: 10 + bytes_out: 10 communication_groups { - group_ids: 0 - group_ids: 0 + group_ids: 6 + group_ids: 6 } } instructions { - name: "reduce-scatter_dim-0_bidir-ring_ccw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" opcode: "call" - instruction_id: 37 - operand_ids: 36 + instruction_id: 73 + operand_ids: 72 inner_subroutines { name: "reduction_subroutine_phase_1" - subroutine_root_id: 40 + subroutine_root_id: 76 execution_probability: 1 execution_count: 1 instructions { name: "op1_phase_1" opcode: "delay" - instruction_id: 38 - bytes_out: 5 + instruction_id: 74 + bytes_out: 10 } instructions { name: "op2_phase_1" opcode: "delay" - instruction_id: 39 - bytes_out: 5 + instruction_id: 75 + bytes_out: 10 } instructions { name: "sum_phase_1" opcode: "delay" - instruction_id: 40 - ops: 10 - operand_ids: 38 - operand_ids: 39 + instruction_id: 76 + ops: 20 + operand_ids: 74 + operand_ids: 75 } } } } } instructions { - name: "reduce-scatter_dim-0_bidir-ring_root_2" + name: "reduce-scatter_stage-1_dim-1_bidir-ring_root_2" opcode: "null" - instruction_id: 41 - operand_ids: 29 - operand_ids: 35 + instruction_id: 77 + operand_ids: 65 + operand_ids: 71 } } } instructions { - name: "reduce-scatter_dim-1" + name: "reduce-scatter_stage-1_root" + opcode: "null" + instruction_id: 78 + operand_ids: 43 + operand_ids: 60 + } + instructions { + name: "reduce-scatter_conc" opcode: "reduce-scatter" - instruction_id: 42 - bytes_out: 20 + instruction_id: 79 + bytes_out: 80 communication_groups { group_ids: 2 - group_ids: 6 + group_ids: 3 } - operand_ids: 25 + operand_ids: 78 inner_subroutines { - name: "reduce-scatter_dim-1_bidir-ring" - subroutine_root_id: 59 + name: "reduce-scatter_conc_bidir-ring" + subroutine_root_id: 96 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1_bidir-ring_barrier" + name: "reduce-scatter_conc_bidir-ring_barrier" opcode: "barrier" - instruction_id: 43 + instruction_id: 80 communication_groups { group_ids: 2 - group_ids: 6 + group_ids: 3 } inner_subroutines { - name: "reduce-scatter_dim-1_bidir-ring_barrier_centralized" - subroutine_root_id: 46 + name: "reduce-scatter_conc_bidir-ring_barrier_centralized" + subroutine_root_id: 83 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1_bidir-ring_barrier_centralized_coordinator_recv_from_6" + name: "reduce-scatter_conc_bidir-ring_barrier_centralized_coordinator_recv_from_3" opcode: "recv" - instruction_id: 44 + instruction_id: 81 communication_groups { - group_ids: 6 + group_ids: 3 } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_barrier_centralized_coordinator_send_to_6" + name: "reduce-scatter_conc_bidir-ring_barrier_centralized_coordinator_send_to_3" opcode: "send" - instruction_id: 45 + instruction_id: 82 communication_groups { - group_ids: 6 + group_ids: 3 } - operand_ids: 44 + operand_ids: 81 } instructions { - name: "reduce-scatter_dim-1_bidir-ring_barrier_centralized_root_2" + name: "reduce-scatter_conc_bidir-ring_barrier_centralized_root_2" opcode: "null" - instruction_id: 46 - operand_ids: 45 + instruction_id: 83 + operand_ids: 82 } } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_cw" + name: "reduce-scatter_conc_bidir-ring_cw" opcode: "reduce-scatter" - instruction_id: 47 - bytes_out: 10 + instruction_id: 84 + bytes_out: 40 communication_groups { group_ids: 2 - group_ids: 6 + group_ids: 3 } - operand_ids: 43 + operand_ids: 80 inner_subroutines { - name: "reduce-scatter_dim-1_bidir-ring_cw_unidir-ring" - subroutine_root_id: 49 + name: "reduce-scatter_conc_bidir-ring_cw_unidir-ring" + subroutine_root_id: 86 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "reduce-scatter_conc_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 48 - bytes_in: 5 - bytes_out: 5 + instruction_id: 85 + bytes_in: 20 + bytes_out: 20 communication_groups { - group_ids: 6 - group_ids: 6 + group_ids: 3 + group_ids: 3 } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_reduction_1" + name: "reduce-scatter_conc_bidir-ring_cw_unidir-ring_reduction_1" opcode: "call" - instruction_id: 49 - operand_ids: 48 + instruction_id: 86 + operand_ids: 85 inner_subroutines { name: "reduction_subroutine_phase_1" - subroutine_root_id: 52 + subroutine_root_id: 89 execution_probability: 1 execution_count: 1 instructions { name: "op1_phase_1" opcode: "delay" - instruction_id: 50 - bytes_out: 5 + instruction_id: 87 + bytes_out: 20 } instructions { name: "op2_phase_1" opcode: "delay" - instruction_id: 51 - bytes_out: 5 + instruction_id: 88 + bytes_out: 20 } instructions { name: "sum_phase_1" opcode: "delay" - instruction_id: 52 - ops: 10 - operand_ids: 50 - operand_ids: 51 + instruction_id: 89 + ops: 40 + operand_ids: 87 + operand_ids: 88 } } } } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_ccw" + name: "reduce-scatter_conc_bidir-ring_ccw" opcode: "reduce-scatter" - instruction_id: 53 - bytes_out: 10 + instruction_id: 90 + bytes_out: 40 communication_groups { - group_ids: 6 + group_ids: 3 group_ids: 2 } - operand_ids: 43 + operand_ids: 80 inner_subroutines { - name: "reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring" - subroutine_root_id: 55 + name: "reduce-scatter_conc_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 92 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "reduce-scatter_conc_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" - instruction_id: 54 - bytes_in: 5 - bytes_out: 5 + instruction_id: 91 + bytes_in: 20 + bytes_out: 20 communication_groups { - group_ids: 6 - group_ids: 6 + group_ids: 3 + group_ids: 3 } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" + name: "reduce-scatter_conc_bidir-ring_ccw_unidir-ring_reduction_1" opcode: "call" - instruction_id: 55 - operand_ids: 54 + instruction_id: 92 + operand_ids: 91 inner_subroutines { name: "reduction_subroutine_phase_1" - subroutine_root_id: 58 + subroutine_root_id: 95 execution_probability: 1 execution_count: 1 instructions { name: "op1_phase_1" opcode: "delay" - instruction_id: 56 - bytes_out: 5 + instruction_id: 93 + bytes_out: 20 } instructions { name: "op2_phase_1" opcode: "delay" - instruction_id: 57 - bytes_out: 5 + instruction_id: 94 + bytes_out: 20 } instructions { name: "sum_phase_1" opcode: "delay" - instruction_id: 58 - ops: 10 - operand_ids: 56 - operand_ids: 57 + instruction_id: 95 + ops: 40 + operand_ids: 93 + operand_ids: 94 } } } } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_root_2" + name: "reduce-scatter_conc_bidir-ring_root_2" opcode: "null" - instruction_id: 59 - operand_ids: 47 - operand_ids: 53 + instruction_id: 96 + operand_ids: 84 + operand_ids: 90 } } } } - )proto"; - google::protobuf::TextFormat::ParseFromString(reducescatter_str, - &reducescatter_proto); - EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( - reducescatter->ToProto().value(), reducescatter_proto)); -} + )proto"; + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT -// Tests expanding 1D-Torus reduce-scatter -TEST(Torus2dReduceScatter, InconsecutiveProcessors) { +// Tests expanding 1D-Torus reduce-scatter with barrier +TEST(Torus2dReduceScatter, WithBarrier) { auto graph = absl::make_unique("test_graph", 2); auto sub = absl::make_unique( "test_subroutine", graph.get()); auto sub_ptr = sub.get(); + sub_ptr->SetId(3); graph->SetEntrySubroutine(std::move(sub)); ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( @@ -1173,8 +2377,8 @@ TEST(Torus2dReduceScatter, InconsecutiveProcessors) { ASSERT_OK_AND_ASSIGN(auto reducescatter, paragraph::Instruction::Create( paragraph::Opcode::kReduceScatter, "reduce-scatter", sub_ptr)); - reducescatter->SetBytesOut(48); - paragraph::CommunicationGroup reducescatter_group = {0, 2, 4}; + reducescatter->SetBytesOut(80); + paragraph::CommunicationGroup reducescatter_group = {0, 1, 2, 3, 4, 5, 6, 7}; reducescatter->AppendCommunicationGroup(reducescatter_group); auto reduction_sub = absl::make_unique( @@ -1182,13 +2386,13 @@ TEST(Torus2dReduceScatter, InconsecutiveProcessors) { auto reduction_ptr = reduction_sub.get(); ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( paragraph::Opcode::kDelay, "op1", reduction_ptr)); - op1->SetBytesOut(48); + op1->SetBytesOut(80); ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( paragraph::Opcode::kDelay, "op2", reduction_ptr)); - op2->SetBytesOut(48); + op2->SetBytesOut(80); ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); - sum_op->SetOps(96); + sum_op->SetOps(160); sum_op->AddOperand(op1); sum_op->AddOperand(op2); reducescatter->AppendInnerSubroutine(std::move(reduction_sub)); @@ -1201,7 +2405,11 @@ TEST(Torus2dReduceScatter, InconsecutiveProcessors) { { "reduce-scatter": { "algorithm": "torus-2d", - "dimension_widths": [2, 3] + "concentration": 2, + "dimension_widths": [2, 2], + "barrier": { + "algorithm": "centralized" + } } } )"_json; @@ -1210,8 +2418,14 @@ TEST(Torus2dReduceScatter, InconsecutiveProcessors) { paragraph::TranslatorType::kCollective, config)); EXPECT_OK(translators["reduce-scatter"]->Translate(reducescatter)); - paragraph::InstructionProto reducescatter_proto; - std::string reducescatter_str = + paragraph::InstructionProto reducescatter_proto = with_barrier_test_proto(); + EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + reducescatter->ToProto().value(), reducescatter_proto)); +} + +paragraph::InstructionProto inconsecutive_proc_test_proto() { + paragraph::InstructionProto proto; + std::string test_str = R"proto( name: "reduce-scatter" opcode: "reduce-scatter" @@ -1224,11 +2438,11 @@ communication_groups { } inner_subroutines { name: "reduce-scatter_torus-2d" - subroutine_root_id: 7 + subroutine_root_id: 56 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1" + name: "reduce-scatter_stage-0_dim-1" opcode: "reduce-scatter" instruction_id: 7 bytes_out: 48 @@ -1238,12 +2452,12 @@ inner_subroutines { group_ids: 4 } inner_subroutines { - name: "reduce-scatter_dim-1_bidir-ring" + name: "reduce-scatter_stage-0_dim-1_bidir-ring" subroutine_root_id: 30 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1_bidir-ring_cw" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw" opcode: "reduce-scatter" instruction_id: 8 bytes_out: 24 @@ -1253,12 +2467,12 @@ inner_subroutines { group_ids: 4 } inner_subroutines { - name: "reduce-scatter_dim-1_bidir-ring_cw_unidir-ring" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring" subroutine_root_id: 15 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 9 bytes_in: 8 @@ -1269,7 +2483,7 @@ inner_subroutines { } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_reduction_1" opcode: "call" instruction_id: 10 operand_ids: 9 @@ -1301,7 +2515,7 @@ inner_subroutines { } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_sendrecv_2" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_sendrecv_2" opcode: "sendrecv" instruction_id: 14 bytes_in: 8 @@ -1313,7 +2527,7 @@ inner_subroutines { operand_ids: 10 } instructions { - name: "reduce-scatter_dim-1_bidir-ring_cw_unidir-ring_reduction_2" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_cw_unidir-ring_reduction_2" opcode: "call" instruction_id: 15 operand_ids: 14 @@ -1347,7 +2561,7 @@ inner_subroutines { } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_ccw" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw" opcode: "reduce-scatter" instruction_id: 19 bytes_out: 24 @@ -1357,12 +2571,12 @@ inner_subroutines { group_ids: 0 } inner_subroutines { - name: "reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring" subroutine_root_id: 26 execution_probability: 1 execution_count: 1 instructions { - name: "reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" opcode: "sendrecv" instruction_id: 20 bytes_in: 8 @@ -1373,7 +2587,7 @@ inner_subroutines { } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" opcode: "call" instruction_id: 21 operand_ids: 20 @@ -1405,7 +2619,7 @@ inner_subroutines { } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_2" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_2" opcode: "sendrecv" instruction_id: 25 bytes_in: 8 @@ -1417,7 +2631,7 @@ inner_subroutines { operand_ids: 21 } instructions { - name: "reduce-scatter_dim-1_bidir-ring_ccw_unidir-ring_reduction_2" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_ccw_unidir-ring_reduction_2" opcode: "call" instruction_id: 26 operand_ids: 25 @@ -1451,7 +2665,7 @@ inner_subroutines { } } instructions { - name: "reduce-scatter_dim-1_bidir-ring_root_2" + name: "reduce-scatter_stage-0_dim-1_bidir-ring_root_2" opcode: "null" instruction_id: 30 operand_ids: 8 @@ -1459,10 +2673,312 @@ inner_subroutines { } } } + instructions { + name: "reduce-scatter_stage-0_root" + opcode: "null" + instruction_id: 31 + operand_ids: 7 + } + instructions { + name: "reduce-scatter_stage-1_dim-1" + opcode: "reduce-scatter" + instruction_id: 32 + bytes_out: 144 + communication_groups { + group_ids: 0 + group_ids: 2 + group_ids: 4 + } + operand_ids: 31 + inner_subroutines { + name: "reduce-scatter_stage-1_dim-1_bidir-ring" + subroutine_root_id: 55 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw" + opcode: "reduce-scatter" + instruction_id: 33 + bytes_out: 72 + communication_groups { + group_ids: 0 + group_ids: 2 + group_ids: 4 + } + inner_subroutines { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring" + subroutine_root_id: 40 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 34 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 0 + group_ids: 4 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 35 + operand_ids: 34 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 38 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 36 + bytes_out: 24 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 37 + bytes_out: 24 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 38 + ops: 48 + operand_ids: 36 + operand_ids: 37 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 39 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 0 + group_ids: 4 + } + operand_ids: 35 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_cw_unidir-ring_reduction_2" + opcode: "call" + instruction_id: 40 + operand_ids: 39 + inner_subroutines { + name: "reduction_subroutine_phase_2" + subroutine_root_id: 43 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_2" + opcode: "delay" + instruction_id: 41 + bytes_out: 24 + } + instructions { + name: "op2_phase_2" + opcode: "delay" + instruction_id: 42 + bytes_out: 24 + } + instructions { + name: "sum_phase_2" + opcode: "delay" + instruction_id: 43 + ops: 48 + operand_ids: 41 + operand_ids: 42 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw" + opcode: "reduce-scatter" + instruction_id: 44 + bytes_out: 72 + communication_groups { + group_ids: 4 + group_ids: 2 + group_ids: 0 + } + inner_subroutines { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring" + subroutine_root_id: 51 + execution_probability: 1 + execution_count: 1 + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_1" + opcode: "sendrecv" + instruction_id: 45 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 4 + group_ids: 0 + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_reduction_1" + opcode: "call" + instruction_id: 46 + operand_ids: 45 + inner_subroutines { + name: "reduction_subroutine_phase_1" + subroutine_root_id: 49 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_1" + opcode: "delay" + instruction_id: 47 + bytes_out: 24 + } + instructions { + name: "op2_phase_1" + opcode: "delay" + instruction_id: 48 + bytes_out: 24 + } + instructions { + name: "sum_phase_1" + opcode: "delay" + instruction_id: 49 + ops: 48 + operand_ids: 47 + operand_ids: 48 + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_sendrecv_2" + opcode: "sendrecv" + instruction_id: 50 + bytes_in: 24 + bytes_out: 24 + communication_groups { + group_ids: 4 + group_ids: 0 + } + operand_ids: 46 + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_ccw_unidir-ring_reduction_2" + opcode: "call" + instruction_id: 51 + operand_ids: 50 + inner_subroutines { + name: "reduction_subroutine_phase_2" + subroutine_root_id: 54 + execution_probability: 1 + execution_count: 1 + instructions { + name: "op1_phase_2" + opcode: "delay" + instruction_id: 52 + bytes_out: 24 + } + instructions { + name: "op2_phase_2" + opcode: "delay" + instruction_id: 53 + bytes_out: 24 + } + instructions { + name: "sum_phase_2" + opcode: "delay" + instruction_id: 54 + ops: 48 + operand_ids: 52 + operand_ids: 53 + } + } + } + } + } + instructions { + name: "reduce-scatter_stage-1_dim-1_bidir-ring_root_2" + opcode: "null" + instruction_id: 55 + operand_ids: 33 + operand_ids: 44 + } + } + } + instructions { + name: "reduce-scatter_stage-1_root" + opcode: "null" + instruction_id: 56 + operand_ids: 32 + } } - )proto"; - google::protobuf::TextFormat::ParseFromString(reducescatter_str, - &reducescatter_proto); + )proto"; + google::protobuf::TextFormat::ParseFromString(test_str, + &proto); + return proto; +} // NOLINT + +// Tests expanding 1D-Torus reduce-scatter +TEST(Torus2dReduceScatter, InconsecutiveProcessors) { + auto graph = absl::make_unique("test_graph", 2); + auto sub = absl::make_unique( + "test_subroutine", graph.get()); + auto sub_ptr = sub.get(); + graph->SetEntrySubroutine(std::move(sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "first_instruction", sub_ptr)); + instr_1->SetOps(4); + + ASSERT_OK_AND_ASSIGN(auto reducescatter, + paragraph::Instruction::Create( + paragraph::Opcode::kReduceScatter, "reduce-scatter", sub_ptr)); + reducescatter->SetBytesOut(48); + paragraph::CommunicationGroup reducescatter_group = {0, 2, 4}; + reducescatter->AppendCommunicationGroup(reducescatter_group); + + auto reduction_sub = absl::make_unique( + "reduction_subroutine", graph.get()); + auto reduction_ptr = reduction_sub.get(); + ASSERT_OK_AND_ASSIGN(auto op1, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "op1", reduction_ptr)); + op1->SetBytesOut(48); + ASSERT_OK_AND_ASSIGN(auto op2, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "op2", reduction_ptr)); + op2->SetBytesOut(48); + ASSERT_OK_AND_ASSIGN(auto sum_op, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "sum", reduction_ptr, true)); + sum_op->SetOps(96); + sum_op->AddOperand(op1); + sum_op->AddOperand(op2); + reducescatter->AppendInnerSubroutine(std::move(reduction_sub)); + + ASSERT_OK_AND_ASSIGN(auto instr_3, paragraph::Instruction::Create( + paragraph::Opcode::kDelay, "last_instruction", sub_ptr, true)); + instr_3->SetOps(4); + + nlohmann::json config = R"( + { + "reduce-scatter": { + "algorithm": "torus-2d", + "dimension_widths": [2, 3] + } + } + )"_json; + + ASSERT_OK_AND_ASSIGN(auto translators, paragraph::CreateTranslators( + paragraph::TranslatorType::kCollective, config)); + EXPECT_OK(translators["reduce-scatter"]->Translate(reducescatter)); + + paragraph::InstructionProto reducescatter_proto = + inconsecutive_proc_test_proto(); EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals( reducescatter->ToProto().value(), reducescatter_proto)); } diff --git a/paragraph/translation/utils.cc b/paragraph/translation/utils.cc index f0fe768..3ed85d9 100644 --- a/paragraph/translation/utils.cc +++ b/paragraph/translation/utils.cc @@ -15,6 +15,7 @@ #include "paragraph/translation/utils.h" #include +#include namespace paragraph { @@ -45,6 +46,67 @@ uint64_t GridCoordinatesToConsecutiveProcessorId( return processor_id; } +CommunicationGroup CommunicationGroupLocalProjection( + int64_t processor_id, + const CommunicationGroup& comm_group, + const std::vector& dimension_sizes, + uint64_t concentration) { + std::vector processor_coordinates; + std::unordered_set whole_world(comm_group.begin(), comm_group.end()); + // Check if we have non-trivial concentration first and need to perform + // explicit local exchange step + CommunicationGroup new_comm_group; + if ((concentration > 1)) { + processor_coordinates = ConsecutiveProcessorIdToGridCoordinates( + processor_id, dimension_sizes, concentration); + for (uint64_t i = 0; i < concentration; i++) { + processor_coordinates.at(0) = i; + uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId( + processor_coordinates, dimension_sizes, concentration); + if (whole_world.find(new_processor_id) != whole_world.end()) { + new_comm_group.push_back(new_processor_id); + } + } + } + return new_comm_group; +} + +CommunicationGroup CommunicationGroupProjectionOnGrid( + int64_t processor_id, + const CommunicationGroup& comm_group, + size_t dimension, + bool include_concentrators, + const std::vector& dimension_sizes, + uint64_t concentration) { + std::vector processor_coordinates = + ConsecutiveProcessorIdToGridCoordinates(processor_id, + dimension_sizes, + concentration); + std::unordered_set whole_world(comm_group.begin(), comm_group.end()); + CommunicationGroup new_comm_group; + uint64_t dim_width = dimension_sizes.at(dimension); + for (uint64_t i = 0; i < dim_width; i++) { + processor_coordinates.at(dimension + 1) = i; + if (include_concentrators) { + for (uint64_t j = 0; j < concentration; j++) { + processor_coordinates.at(0) = j; + uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId( + processor_coordinates, dimension_sizes, concentration); + if (whole_world.find(new_processor_id) != whole_world.end()) { + new_comm_group.push_back(new_processor_id); + } + } + } else { + uint64_t new_processor_id = GridCoordinatesToConsecutiveProcessorId( + processor_coordinates, dimension_sizes, concentration); + if (whole_world.find(new_processor_id) != whole_world.end()) { + new_comm_group.push_back(new_processor_id); + } + } + } + return new_comm_group; +} + CommunicationGroup Swizzling2dGridToRing( const std::vector& dimension_sizes, uint64_t concentration) { diff --git a/paragraph/translation/utils.h b/paragraph/translation/utils.h index e037cc3..367e466 100644 --- a/paragraph/translation/utils.h +++ b/paragraph/translation/utils.h @@ -33,6 +33,20 @@ uint64_t GridCoordinatesToConsecutiveProcessorId( const std::vector& dimension_sizes, uint64_t concentration); +CommunicationGroup CommunicationGroupLocalProjection( + int64_t processor_id, + const CommunicationGroup& comm_group, + const std::vector& dimension_sizes, + uint64_t concentration); + +CommunicationGroup CommunicationGroupProjectionOnGrid( + int64_t processor_id, + const CommunicationGroup& comm_group, + size_t dimension, + bool include_concentrators, + const std::vector& dimension_sizes, + uint64_t concentration); + // 2D swizzling algorithm that produces Hamiltonian cycle through all the // vertices of 2D Grid, such as Mesh or Torus. It is used to map these // topologies onto logical ring topology. This is not the optimal algorithm as diff --git a/paragraph/translation/utils_test.cc b/paragraph/translation/utils_test.cc index c6a5430..5f47cc9 100644 --- a/paragraph/translation/utils_test.cc +++ b/paragraph/translation/utils_test.cc @@ -53,6 +53,37 @@ TEST(TranslationUtils, GridCoordinatesToConsecutiveProcessorId) { 22); } +// Tests Communication group intersection with local processors +TEST(TranslationUtils, CommunicationGroupLocalProjection) { + uint64_t concentration = 2; + std::vector dimension_sizes = {4, 3}; + paragraph::CommunicationGroup comm_group = {0, 1, 2, 3, 4, 5, 11, 12}; + paragraph::CommunicationGroup test_group = {2, 3}; + EXPECT_EQ(paragraph::CommunicationGroupLocalProjection( + 3, comm_group, dimension_sizes, concentration), + test_group); + paragraph::CommunicationGroup test_group_2; + EXPECT_EQ(paragraph::CommunicationGroupLocalProjection( + 3, {0, 1}, dimension_sizes, concentration), + test_group_2); +} + +// Tests Communication group intersection with processors in particular +// dimensions +TEST(TranslationUtils, CommunicationGroupProjectionOnGrid) { + uint64_t concentration = 2; + std::vector dimension_sizes = {2, 3}; + paragraph::CommunicationGroup comm_group = {0, 1, 2, 3, 4, 5, 10, 11}; + paragraph::CommunicationGroup test_group = {2, 3, 10, 11}; + EXPECT_EQ(paragraph::CommunicationGroupProjectionOnGrid( + 3, comm_group, 1, true, dimension_sizes, concentration), + test_group); + paragraph::CommunicationGroup test_group_2 = {1, 3}; + EXPECT_EQ(paragraph::CommunicationGroupProjectionOnGrid( + 1, comm_group, 0, false, dimension_sizes, concentration), + test_group_2); +} + // Tests 2d swizzling to map 2D grid on a logical ring TEST(TranslationUtils, Swizzling2dGridToRing) { uint64_t concentration = 2;