Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
HilbertLoadBalance.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
#include "shambase/string.hpp"
#include "shambackends/vec.hpp"
#include <limits>
25
26inline void apply_node_patch_packing(
27 std::vector<shamrock::patch::Patch> &global_patch_list, std::vector<i32> &new_owner_table) {
28
29 // Note that there seems to be a data race here
30 // However this should never happends as packing index will only point toward a patch without
31 // packing. As such the data we are accessing should never be modified during this loop.
32#pragma omp parallel for
33 for (size_t i = 0; i < global_patch_list.size(); i++) {
34 if (global_patch_list[i].pack_node_index != u64_max) {
35 new_owner_table[i] = new_owner_table[global_patch_list[i].pack_node_index];
36 }
37 }
38}
39
40namespace shamrock::scheduler {
41
42 template<class T>
44 template<class T>
46 template<class T>
48
49 template<class hilbert_num>
51 std::vector<shamrock::patch::Patch> &global_patch_list) {
52
53 StackEntry stack_loc{};
54 using namespace shamrock::patch;
55
56 // result
57 LoadBalancingChangeList change_list;
58
59 using Torder = hilbert_num;
60 using Tweight = u64;
61 using LBTile = TileWithLoad<Torder, Tweight>;
62
63 // generate hilbert code, load value, and index before sort
64 std::vector<LBTile> patch_dt(global_patch_list.size());
65
66#pragma omp parallel for
67 for (u64 i = 0; i < global_patch_list.size(); i++) {
68
69 const Patch &p = global_patch_list[i];
70
71 patch_dt[i]
72 = {SFC::icoord_to_hilbert(p.coord_min[0], p.coord_min[1], p.coord_min[2]),
73 p.load_value};
74 }
75
76 std::vector<i32> new_owner_table = load_balance(std::move(patch_dt));
77
78 // apply patch packing in same node for merge
79 apply_node_patch_packing(global_patch_list, new_owner_table);
80
81 // make change list
82 {
83 std::vector<u64> load_per_node(shamcomm::world_size());
84
85 std::vector<i32> tags_it_node(shamcomm::world_size());
86 for (u64 i = 0; i < global_patch_list.size(); i++) {
87
88 i32 old_owner = global_patch_list[i].node_owner_id;
89 i32 new_owner = new_owner_table[i];
90
91 // TODO add bool for optional print verbosity
92 // std::cout << i << " : " << old_owner << " -> " << new_owner << std::endl;
93 if (new_owner != old_owner) {
94
95 using ChangeOp = LoadBalancingChangeList::ChangeOp;
96
97 ChangeOp op;
98 op.patch_idx = i;
99 op.patch_id = global_patch_list[i].id_patch;
100 op.rank_owner_new = new_owner;
101 op.rank_owner_old = old_owner;
102 op.tag_comm = tags_it_node[old_owner];
103
104 change_list.change_ops.push_back(op);
105 tags_it_node[old_owner]++;
106 }
107
108 load_per_node[new_owner_table[i]] += global_patch_list[i].load_value;
109 }
110
111 // shamlog_debug_ln("HilbertLoadBalance", "loads after balancing");
114 f64 avg = 0;
115 f64 var = 0;
116
117 i32 world_size = shamcomm::world_size();
118
119#pragma omp parallel for reduction(min : min) reduction(max : max) reduction(+ : avg)
120 for (i32 nid = 0; nid < world_size; nid++) {
121 f64 val = load_per_node[nid];
122 min = sycl::fmin(min, val);
123 max = sycl::fmax(max, val);
124 avg += val;
125 }
126
127 if (shamcomm::world_rank() == 0
128 && shamcomm::logs::get_loglevel() >= shamcomm::logs::log_debug) {
129 for (i32 nid = 0; nid < world_size; nid++) {
130 shamlog_debug_ln(
131 "HilbertLoadBalance", "node :", nid, "load :", load_per_node[nid]);
132 }
133 }
134 avg /= world_size;
135
136#pragma omp parallel for reduction(+ : var)
137 for (i32 nid = 0; nid < world_size; nid++) {
138 f64 val = load_per_node[nid];
139 var += (val - avg) * (val - avg);
140 }
141 var /= world_size;
142
143 if (shamcomm::world_rank() == 0) {
144 std::string str = "Loadbalance stats : \n";
145 str += shambase::format(" npatch = {}\n", global_patch_list.size());
146 str += shambase::format(" min = {}\n", min);
147 str += shambase::format(" max = {}\n", max);
148 str += shambase::format(" avg = {}\n", avg);
149 if (max == 0) {
150 str += " efficiency = ???%";
151 } else {
152 str += shambase::format(
153 " efficiency = {:.2f}%", 100 - (100 * (max - min) / max));
154 }
155 logger::info_ln("LoadBalance", str);
156 }
157 }
158
159 return change_list;
160 }
161
162 template class HilbertLoadBalance<u64>;
164
165} // namespace shamrock::scheduler
Function to run load balancing with the Hilbert curve.
Implementation of the Hilbert curve load balancing.
std::vector< i32 > load_balance(std::vector< TileWithLoad< Torder, Tweight > > &&lb_vector, i32 world_size=shamcomm::world_size())
load balance the input vector
Header file describing a Node Instance.
double f64
Alias for double.
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
static LoadBalancingChangeList make_change_list(std::vector< shamrock::patch::Patch > &global_patch_list)
Generate the change list from the list of patches to run the load balancing.
i8 get_loglevel()
Get the current global log level.
Definition loglevel.hpp:52
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
i32 world_size()
Gives the size of the MPI communicator.
Definition worldInfo.cpp:38
constexpr u64 u64_max
u64 max value
This file contains the definition for the stacktrace related functionality.
Patch object that contain generic patch information.
Definition Patch.hpp:33
header file to manage sycl