-
Notifications
You must be signed in to change notification settings - Fork 18
[Scheduler] improve psweep LB #1492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -166,6 +166,67 @@ namespace shamrock::scheduler::details { | |||||||||||
| return new_owners; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| template<class Torder, class Tweight> | ||||||||||||
| inline std::vector<i32> lb_startegy_parallel_sweep2( | ||||||||||||
| const std::vector<TileWithLoad<Torder, Tweight>> &lb_vector, i32 wsize) { | ||||||||||||
|
|
||||||||||||
| using LBTile = TileWithLoad<Torder, Tweight>; | ||||||||||||
| using LBTileResult = details::LoadBalancedTile<Torder, Tweight>; | ||||||||||||
|
|
||||||||||||
| std::vector<LBTileResult> res; | ||||||||||||
| for (u64 i = 0; i < lb_vector.size(); i++) { | ||||||||||||
| res.push_back(LBTileResult{lb_vector[i], i}); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| // apply the ordering | ||||||||||||
| apply_ordering(res); | ||||||||||||
|
|
||||||||||||
| // compute increments for load | ||||||||||||
| u64 accum = 0; | ||||||||||||
| for (LBTileResult &tile : res) { | ||||||||||||
| u64 cur_val = tile.load_value; | ||||||||||||
| tile.accumulated_load_value = accum; | ||||||||||||
| accum += cur_val; | ||||||||||||
|
Comment on lines
+187
to
+189
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The local variable
Suggested change
|
||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| double target_datacnt = double(res[res.size() - 1].accumulated_load_value) / wsize; | ||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The calculation of
Suggested change
|
||||||||||||
|
|
||||||||||||
| for (LBTileResult &tile : res) { | ||||||||||||
| tile.new_owner = (target_datacnt == 0) | ||||||||||||
| ? 0 | ||||||||||||
| : sycl::clamp( | ||||||||||||
| i32((tile.accumulated_load_value / target_datacnt) + 0.5), | ||||||||||||
| 0, | ||||||||||||
| wsize - 1); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| if (shamcomm::world_rank() == 0) { | ||||||||||||
| for (LBTileResult t : res) { | ||||||||||||
| shamlog_debug_ln( | ||||||||||||
| "HilbertLoadBalance", | ||||||||||||
| t.ordering_val, | ||||||||||||
| t.load_value, | ||||||||||||
| t.accumulated_load_value, | ||||||||||||
| t.index, | ||||||||||||
| (target_datacnt == 0) | ||||||||||||
| ? 0 | ||||||||||||
| : sycl::clamp( | ||||||||||||
| i32((t.accumulated_load_value / target_datacnt) + 0.5), | ||||||||||||
| 0, | ||||||||||||
| i32(wsize) - 1), | ||||||||||||
| (target_datacnt == 0) ? 0 | ||||||||||||
| : ((t.accumulated_load_value / target_datacnt) + 0.5)); | ||||||||||||
| } | ||||||||||||
|
Comment on lines
+204
to
+219
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The debug logging loop re-calculates the for (const LBTileResult &t : res) {
shamlog_debug_ln(
"HilbertLoadBalance",
t.ordering_val,
t.load_value,
t.accumulated_load_value,
t.index,
t.new_owner,
(target_datacnt == 0)
? 0.0
: ((t.accumulated_load_value / target_datacnt) + 0.5));
} |
||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| std::vector<i32> new_owners(res.size()); | ||||||||||||
| for (LBTileResult &tile : res) { | ||||||||||||
| new_owners[tile.index] = tile.new_owner; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| return new_owners; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| struct LBMetric { | ||||||||||||
| f64 min; | ||||||||||||
| f64 max; | ||||||||||||
|
|
@@ -225,32 +286,63 @@ namespace shamrock::scheduler { | |||||||||||
| std::vector<TileWithLoad<Torder, Tweight>> &&lb_vector, | ||||||||||||
| i32 world_size = shamcomm::world_size()) { | ||||||||||||
|
|
||||||||||||
| auto tmpres = details::lb_startegy_parallel_sweep(lb_vector, world_size); | ||||||||||||
| auto metric_psweep = details::compute_LB_metric(lb_vector, tmpres, world_size); | ||||||||||||
| std::vector<i32> res_best{}; | ||||||||||||
| details::LBMetric metric_best{ | ||||||||||||
| shambase::VectorProperties<f64>::get_inf(), | ||||||||||||
| shambase::VectorProperties<f64>::get_inf(), | ||||||||||||
| shambase::VectorProperties<f64>::get_inf(), | ||||||||||||
| shambase::VectorProperties<f64>::get_inf()}; | ||||||||||||
| std::string strategy_best = ""; | ||||||||||||
|
|
||||||||||||
| struct LBResult { | ||||||||||||
| std::vector<i32> ranks; | ||||||||||||
| details::LBMetric metric; | ||||||||||||
| std::string strategy; | ||||||||||||
| }; | ||||||||||||
|
|
||||||||||||
| std::vector<LBResult> results; | ||||||||||||
|
|
||||||||||||
| { | ||||||||||||
| auto tmpres = details::lb_startegy_parallel_sweep(lb_vector, world_size); | ||||||||||||
| auto metric = details::compute_LB_metric(lb_vector, tmpres, world_size); | ||||||||||||
| results.push_back(LBResult{tmpres, metric, "parallel sweep"}); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| auto tmpres_2 = details::lb_startegy_roundrobin(lb_vector, world_size); | ||||||||||||
| auto metric_rrobin = details::compute_LB_metric(lb_vector, tmpres_2, world_size); | ||||||||||||
| { | ||||||||||||
| auto tmpres = details::lb_startegy_roundrobin(lb_vector, world_size); | ||||||||||||
| auto metric = details::compute_LB_metric(lb_vector, tmpres, world_size); | ||||||||||||
| results.push_back(LBResult{tmpres, metric, "round robin"}); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| { | ||||||||||||
| auto tmpres = details::lb_startegy_parallel_sweep2(lb_vector, world_size); | ||||||||||||
| auto metric = details::compute_LB_metric(lb_vector, tmpres, world_size); | ||||||||||||
| results.push_back(LBResult{tmpres, metric, "parallel sweep 2"}); | ||||||||||||
| } | ||||||||||||
|
Comment on lines
+305
to
+321
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is repeated code for running each load balancing strategy and storing the results. This can be refactored into a helper lambda to improve readability and maintainability, and to avoid potential vector copies when storing results. using strategy_fn = std::vector<i32>(
const std::vector<TileWithLoad<Torder, Tweight>> &, i32);
auto run_and_store_strategy = [&](strategy_fn *strategy, const std::string &name) {
auto ranks = strategy(lb_vector, world_size);
auto metric = details::compute_LB_metric(lb_vector, ranks, world_size);
results.emplace_back(LBResult{std::move(ranks), metric, name});
};
run_and_store_strategy(&details::lb_startegy_parallel_sweep, "parallel sweep");
run_and_store_strategy(&details::lb_startegy_roundrobin, "round robin");
run_and_store_strategy(&details::lb_startegy_parallel_sweep2, "parallel sweep 2"); |
||||||||||||
|
|
||||||||||||
| if (metric_rrobin.max < metric_psweep.max) { | ||||||||||||
| tmpres = tmpres_2; | ||||||||||||
| for (const auto &result : results) { | ||||||||||||
| if (shamcomm::world_rank() == 0) { | ||||||||||||
| logger::info_ln( | ||||||||||||
| "LoadBalance", | ||||||||||||
| " - strategy \"", | ||||||||||||
| result.strategy, | ||||||||||||
| "\" : max =", | ||||||||||||
| result.metric.max, | ||||||||||||
| "min =", | ||||||||||||
| result.metric.min); | ||||||||||||
| } | ||||||||||||
| if (result.metric.max < metric_best.max) { | ||||||||||||
| metric_best = result.metric; | ||||||||||||
| res_best = result.ranks; | ||||||||||||
| strategy_best = result.strategy; | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
|
Comment on lines
+323
to
339
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This loop for finding the best strategy can be made more concise and idiomatic by using for (const auto &result : results) {
if (shamcomm::world_rank() == 0) {
logger::info_ln(
"LoadBalance",
" - strategy \"",
result.strategy,
"\" : max =",
result.metric.max,
"min =",
result.metric.min);
}
}
auto best_it = std::min_element(
results.begin(),
results.end(),
[](const LBResult &a, const LBResult &b) { return a.metric.max < b.metric.max; });
if (best_it != results.end()) {
res_best = best_it->ranks;
strategy_best = best_it->strategy;
} |
||||||||||||
|
|
||||||||||||
| if (shamcomm::world_rank() == 0) { | ||||||||||||
| logger::info_ln("LoadBalance", "summary :"); | ||||||||||||
| logger::info_ln( | ||||||||||||
| "LoadBalance", | ||||||||||||
| " - strategy \"psweep\" : max =", | ||||||||||||
| metric_psweep.max, | ||||||||||||
| "min =", | ||||||||||||
| metric_psweep.min); | ||||||||||||
| logger::info_ln( | ||||||||||||
| "LoadBalance", | ||||||||||||
| " - strategy \"round robin\" : max =", | ||||||||||||
| metric_rrobin.max, | ||||||||||||
| "min =", | ||||||||||||
| metric_rrobin.min); | ||||||||||||
| logger::info_ln("LoadBalance", "best strategy :", strategy_best); | ||||||||||||
| } | ||||||||||||
| return tmpres; | ||||||||||||
|
|
||||||||||||
| return res_best; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| } // namespace shamrock::scheduler | ||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function can crash if
lb_vectoris empty. An emptylb_vectorwill result inresbeing empty, and the accessres[res.size() - 1]on line 192 will be out-of-bounds. Please add a check at the beginning of the function to handle this case.For example:
if (lb_vector.empty()) {
return {};
}