Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,67 @@ namespace shamrock::scheduler::details {
return new_owners;
}

template<class Torder, class Tweight>
inline std::vector<i32> lb_startegy_parallel_sweep2(
const std::vector<TileWithLoad<Torder, Tweight>> &lb_vector, i32 wsize) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This function can crash if lb_vector is empty. An empty lb_vector will result in res being empty, and the access res[res.size() - 1] on line 192 will be out-of-bounds. Please add a check at the beginning of the function to handle this case.

For example:
if (lb_vector.empty()) {
return {};
}


using LBTile = TileWithLoad<Torder, Tweight>;
using LBTileResult = details::LoadBalancedTile<Torder, Tweight>;

std::vector<LBTileResult> res;
for (u64 i = 0; i < lb_vector.size(); i++) {
res.push_back(LBTileResult{lb_vector[i], i});
}

// apply the ordering
apply_ordering(res);

// compute increments for load
u64 accum = 0;
for (LBTileResult &tile : res) {
u64 cur_val = tile.load_value;
tile.accumulated_load_value = accum;
accum += cur_val;
Comment on lines +187 to +189

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The local variable cur_val is redundant. You can use tile.load_value directly in the accumulation to make the code slightly simpler and more direct.

Suggested change
u64 cur_val = tile.load_value;
tile.accumulated_load_value = accum;
accum += cur_val;
tile.accumulated_load_value = accum;
accum += tile.load_value;

}

double target_datacnt = double(res[res.size() - 1].accumulated_load_value) / wsize;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The calculation of target_datacnt is incorrect. It uses res[res.size() - 1].accumulated_load_value, which is the result of an exclusive scan, meaning it's the sum of all loads except for the last one. To get the total load, you should use the accum variable, which holds the total sum after the loop on lines 186-190 completes. I've also added a check for wsize > 0 to prevent a potential division by zero.

Suggested change
double target_datacnt = double(res[res.size() - 1].accumulated_load_value) / wsize;
double target_datacnt = (wsize > 0) ? (double(accum) / wsize) : 0.0;


for (LBTileResult &tile : res) {
tile.new_owner = (target_datacnt == 0)
? 0
: sycl::clamp(
i32((tile.accumulated_load_value / target_datacnt) + 0.5),
0,
wsize - 1);
}

if (shamcomm::world_rank() == 0) {
for (LBTileResult t : res) {
shamlog_debug_ln(
"HilbertLoadBalance",
t.ordering_val,
t.load_value,
t.accumulated_load_value,
t.index,
(target_datacnt == 0)
? 0
: sycl::clamp(
i32((t.accumulated_load_value / target_datacnt) + 0.5),
0,
i32(wsize) - 1),
(target_datacnt == 0) ? 0
: ((t.accumulated_load_value / target_datacnt) + 0.5));
}
Comment on lines +204 to +219

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The debug logging loop re-calculates the new_owner value, which has already been computed and stored in tile.new_owner. This is code duplication and can lead to inconsistencies if the calculation logic changes in one place but not the other. Additionally, the loop for (LBTileResult t : res) creates a copy of each element in res, which is inefficient. You can simplify the logging by using t.new_owner and iterate by const reference to avoid copies.

            for (const LBTileResult &t : res) {
                shamlog_debug_ln(
                    "HilbertLoadBalance",
                    t.ordering_val,
                    t.load_value,
                    t.accumulated_load_value,
                    t.index,
                    t.new_owner,
                    (target_datacnt == 0)
                        ? 0.0
                        : ((t.accumulated_load_value / target_datacnt) + 0.5));
            }

}

std::vector<i32> new_owners(res.size());
for (LBTileResult &tile : res) {
new_owners[tile.index] = tile.new_owner;
}

return new_owners;
}

struct LBMetric {
f64 min;
f64 max;
Expand Down Expand Up @@ -225,32 +286,63 @@ namespace shamrock::scheduler {
std::vector<TileWithLoad<Torder, Tweight>> &&lb_vector,
i32 world_size = shamcomm::world_size()) {

auto tmpres = details::lb_startegy_parallel_sweep(lb_vector, world_size);
auto metric_psweep = details::compute_LB_metric(lb_vector, tmpres, world_size);
std::vector<i32> res_best{};
details::LBMetric metric_best{
shambase::VectorProperties<f64>::get_inf(),
shambase::VectorProperties<f64>::get_inf(),
shambase::VectorProperties<f64>::get_inf(),
shambase::VectorProperties<f64>::get_inf()};
std::string strategy_best = "";

struct LBResult {
std::vector<i32> ranks;
details::LBMetric metric;
std::string strategy;
};

std::vector<LBResult> results;

{
auto tmpres = details::lb_startegy_parallel_sweep(lb_vector, world_size);
auto metric = details::compute_LB_metric(lb_vector, tmpres, world_size);
results.push_back(LBResult{tmpres, metric, "parallel sweep"});
}

auto tmpres_2 = details::lb_startegy_roundrobin(lb_vector, world_size);
auto metric_rrobin = details::compute_LB_metric(lb_vector, tmpres_2, world_size);
{
auto tmpres = details::lb_startegy_roundrobin(lb_vector, world_size);
auto metric = details::compute_LB_metric(lb_vector, tmpres, world_size);
results.push_back(LBResult{tmpres, metric, "round robin"});
}

{
auto tmpres = details::lb_startegy_parallel_sweep2(lb_vector, world_size);
auto metric = details::compute_LB_metric(lb_vector, tmpres, world_size);
results.push_back(LBResult{tmpres, metric, "parallel sweep 2"});
}
Comment on lines +305 to +321

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is repeated code for running each load balancing strategy and storing the results. This can be refactored into a helper lambda to improve readability and maintainability, and to avoid potential vector copies when storing results.

        using strategy_fn = std::vector<i32>(
            const std::vector<TileWithLoad<Torder, Tweight>> &, i32);

        auto run_and_store_strategy = [&](strategy_fn *strategy, const std::string &name) {
            auto ranks  = strategy(lb_vector, world_size);
            auto metric = details::compute_LB_metric(lb_vector, ranks, world_size);
            results.emplace_back(LBResult{std::move(ranks), metric, name});
        };

        run_and_store_strategy(&details::lb_startegy_parallel_sweep, "parallel sweep");
        run_and_store_strategy(&details::lb_startegy_roundrobin, "round robin");
        run_and_store_strategy(&details::lb_startegy_parallel_sweep2, "parallel sweep 2");


if (metric_rrobin.max < metric_psweep.max) {
tmpres = tmpres_2;
for (const auto &result : results) {
if (shamcomm::world_rank() == 0) {
logger::info_ln(
"LoadBalance",
" - strategy \"",
result.strategy,
"\" : max =",
result.metric.max,
"min =",
result.metric.min);
}
if (result.metric.max < metric_best.max) {
metric_best = result.metric;
res_best = result.ranks;
strategy_best = result.strategy;
}
}
Comment on lines +323 to 339

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This loop for finding the best strategy can be made more concise and idiomatic by using std::min_element. This also separates the concern of logging from the logic of finding the best result, improving readability. The current implementation also copies the ranks vector every time a better strategy is found, which can be inefficient. Using std::min_element will find the best result, and then you can copy the data just once.

        for (const auto &result : results) {
            if (shamcomm::world_rank() == 0) {
                logger::info_ln(
                    "LoadBalance",
                    " - strategy \"",
                    result.strategy,
                    "\" : max =",
                    result.metric.max,
                    "min =",
                    result.metric.min);
            }
        }

        auto best_it = std::min_element(
            results.begin(),
            results.end(),
            [](const LBResult &a, const LBResult &b) { return a.metric.max < b.metric.max; });

        if (best_it != results.end()) {
            res_best      = best_it->ranks;
            strategy_best = best_it->strategy;
        }


if (shamcomm::world_rank() == 0) {
logger::info_ln("LoadBalance", "summary :");
logger::info_ln(
"LoadBalance",
" - strategy \"psweep\" : max =",
metric_psweep.max,
"min =",
metric_psweep.min);
logger::info_ln(
"LoadBalance",
" - strategy \"round robin\" : max =",
metric_rrobin.max,
"min =",
metric_rrobin.min);
logger::info_ln("LoadBalance", "best strategy :", strategy_best);
}
return tmpres;

return res_best;
}

} // namespace shamrock::scheduler
Loading