// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

/**
 * @file SchedulerPatchData.cpp
 * @brief implementation of SchedulerPatchData: patchdata ownership on the
 * current rank, load-balancing exchanges, and patch split/merge operations
 */

// NOTE: part of the original include list was lost in extraction; the set
// below is reconstructed from the symbols used in this file, so exact header
// paths may differ from upstream.
#include "shambase/aliases_int.hpp"
#include "shambase/exception.hpp"
#include "shambase/memory.hpp"
#include "shambase/stacktrace.hpp"
#include "shambase/string.hpp"
#include "shamalgs/serialize.hpp"
#include "shambackends/DeviceBuffer.hpp"
#include "shamcomm/CommunicationBuffer.hpp"
#include "shamcomm/worldInfo.hpp"
#include "shamcomm/wrapper.hpp"
#include "shamrock/scheduler/HilbertLoadBalance.hpp"
#include "shamrock/scheduler/SchedulerPatchData.hpp"
#include "shamrock/scheduler/SchedulerPatchList.hpp"
#include "shamsys/NodeInstance.hpp"
#include <stdexcept>
#include <vector>

// Toggle between the new CommunicationBuffer-based implementation of
// apply_change_list and the legacy patchdata MPI helpers below.
#define NEW_LB_APPLY_IMPL

#ifndef NEW_LB_APPLY_IMPL
    // NOTE: reconstructed include; the legacy implementation needs the
    // patchdata_isend / patchdata_irecv_probe helpers (exact path assumed).
    #include "shamrock/legacy/patch/base/patchdata.hpp"
#endif

namespace shamrock::scheduler {

#ifdef NEW_LB_APPLY_IMPL
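    /// A point-to-point payload for the load-balancing exchange: the
    /// communication buffer to ship, plus the remote rank and the MPI tag
    /// identifying the patch transfer.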
    struct Message {
        std::unique_ptr<shamcomm::CommunicationBuffer> buf;
        i32 rank;
        i32 tag;
    };

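    /// Post one non-blocking MPI_Isend per message in msgs, appending the
    /// matching request to rqs. Payloads are shipped as u64 words, hence the
    /// 8-byte packing check.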
    void send_messages(std::vector<Message> &msgs, std::vector<MPI_Request> &rqs) {
        __shamrock_stack_entry(); // reconstructed: stack tracing entry (dropped in extraction)
        for (auto &msg : msgs) {
            rqs.push_back(MPI_Request{});
            u32 rq_index = rqs.size() - 1;
            auto &rq = rqs[rq_index];

            u64 bsize = msg.buf->get_size();
            if (bsize % 8 != 0) {
                shambase::throw_with_loc<std::runtime_error>(
                    "this MPI communication assumes that the payload can be packed into 8-byte words");
            }
            u64 lcount = bsize / 8;
            if (lcount > i32_max) {
                shambase::throw_with_loc<std::runtime_error>("The message is too large for MPI");
            }

            shamcomm::mpi::Isend(
                msg.buf->get_ptr(),
                static_cast<i32>(lcount), // safe narrowing: range-checked above
                get_mpi_type<u64>(),
                msg.rank,
                msg.tag,
                MPI_COMM_WORLD,
                &rq);
        }
    }

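    /// For each expected message, probe the matching (source rank, tag) pair to
    /// get the payload size, allocate the communication buffer accordingly,
    /// then post a non-blocking MPI_Irecv into it, appending the request to rqs.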
    void recv_probe_messages(std::vector<Message> &msgs, std::vector<MPI_Request> &rqs) {
        __shamrock_stack_entry(); // reconstructed: stack tracing entry (dropped in extraction)
        for (auto &msg : msgs) {
            rqs.push_back(MPI_Request{});
            u32 rq_index = rqs.size() - 1;
            auto &rq = rqs[rq_index];

            MPI_Status st;
            i32 cnt;
            shamcomm::mpi::Probe(msg.rank, msg.tag, MPI_COMM_WORLD, &st);
            shamcomm::mpi::Get_count(&st, get_mpi_type<u64>(), &cnt);

            msg.buf = std::make_unique<shamcomm::CommunicationBuffer>(
                static_cast<u64>(cnt) * 8, // payload size in bytes (cnt u64 words)
                shamsys::instance::get_compute_scheduler_ptr());

            shamcomm::mpi::Irecv(
                msg.buf->get_ptr(),
                cnt,
                get_mpi_type<u64>(),
                msg.rank,
                msg.tag,
                MPI_COMM_WORLD,
                &rq);
        }
    }

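    /// Apply a load-balancing change list: serialize every patchdata this rank
    /// hands off, post the sends, probe and post the receives, wait on all
    /// requests, deserialize the received payloads, then update patch ownership
    /// and erase the patchdata that left this rank.
    ///
    /// Hypothetical usage sketch (names are illustrative only; the real call
    /// site lives in the scheduler update):
    /// @code
    /// LoadBalancingChangeList ops = ...; // produced by the load balancer
    /// scheduler_patch_data.apply_change_list(ops, scheduler_patch_list);
    /// @endcode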
    void SchedulerPatchData::apply_change_list(
        const shamrock::scheduler::LoadBalancingChangeList &change_list,
        SchedulerPatchList &patch_list) {

        StackEntry stack_loc{};

        // reconstructed: bring the nested change-op type into scope
        using ChangeOp = LoadBalancingChangeList::ChangeOp;

        auto serializer = [](shamrock::patch::PatchDataLayer &pdat) {
            shamalgs::SerializeHelper ser(shamsys::instance::get_compute_scheduler_ptr());
            ser.allocate(pdat.serialize_buf_byte_size());
            pdat.serialize_buf(ser);
            return ser.finalize();
        };

        auto deserializer = [&](sham::DeviceBuffer<u8> &&buf) {
            // hand the received buffer over to the serialize helper
            shamalgs::SerializeHelper ser(
                shamsys::instance::get_compute_scheduler_ptr(),
                std::forward<sham::DeviceBuffer<u8>>(buf));
            return shamrock::patch::PatchDataLayer::deserialize_buf(ser, pdl_ptr);
        };

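        // stage every patch this rank hands off: serialize it into a
        // communication buffer addressed to its new owner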
        std::vector<Message> send_payloads;
        for (const ChangeOp op : change_list.change_ops) {
            // this rank is the sender
            if (op.rank_owner_old == shamcomm::world_rank()) {
                auto &patchdata = owned_data.get(op.patch_id);

                sham::DeviceBuffer<u8> tmp = serializer(patchdata);

                send_payloads.push_back(
                    Message{
                        std::make_unique<shamcomm::CommunicationBuffer>(
                            std::move(tmp), shamsys::instance::get_compute_scheduler_ptr()),
                        op.rank_owner_new,
                        op.tag_comm});
            }
        }

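        // post all the non-blocking sends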
        std::vector<MPI_Request> rqs;
        send_messages(send_payloads, rqs);

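        // stage the receives this rank expects (buffers are allocated at probe time)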
        std::vector<Message> recv_payloads;
        for (const ChangeOp op : change_list.change_ops) {
            // this rank is the receiver
            if (op.rank_owner_new == shamcomm::world_rank()) {
                recv_payloads.push_back(
                    Message{
                        std::unique_ptr<shamcomm::CommunicationBuffer>{},
                        op.rank_owner_old,
                        op.tag_comm});
            }
        }

        // probe and post the non-blocking receives
        recv_probe_messages(recv_payloads, rqs);

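        // wait for every posted send and receive to complete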
        std::vector<MPI_Status> st_lst(rqs.size());
        shamcomm::mpi::Waitall(rqs.size(), rqs.data(), st_lst.data());

        u32 idx = 0;
        // unpack the received payloads (same ordering as the staging loop above)
        for (const ChangeOp op : change_list.change_ops) {
            auto &id_patch = op.patch_id;

            // this rank is the receiver
            if (op.rank_owner_new == shamcomm::world_rank()) {
                Message &msg = recv_payloads[idx];

                shamcomm::CommunicationBuffer comm_buf = shambase::extract_pointer(msg.buf);

                sham::DeviceBuffer<u8> buf
                    = shamcomm::CommunicationBuffer::convert_usm(std::move(comm_buf));

                owned_data.add_obj(id_patch, deserializer(std::move(buf)));

                idx++;
            }
        }

        // update ownership and erase the patchdata this rank handed off
        for (const ChangeOp op : change_list.change_ops) {
            auto &id_patch = op.patch_id;

            patch_list.global[op.patch_idx].node_owner_id = op.rank_owner_new;

            // this rank was the sender: drop the local copy
            if (op.rank_owner_old == shamcomm::world_rank()) {
                owned_data.erase(id_patch);
            }
        }
    }
#else

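    /// Legacy implementation of apply_change_list, kept for reference: it
    /// receives directly into freshly inserted empty patchdata using the
    /// patchdata MPI helpers instead of explicit communication buffers.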
    void SchedulerPatchData::apply_change_list(
        const shamrock::scheduler::LoadBalancingChangeList &change_list,
        SchedulerPatchList &patch_list) {

        StackEntry stack_loc{};

        std::vector<PatchDataMpiRequest> rq_lst;

        // reconstructed: bring the nested change-op type into scope
        using ChangeOp = LoadBalancingChangeList::ChangeOp;

        // send
        for (const ChangeOp op : change_list.change_ops) {
            // this rank is the sender
            if (op.rank_owner_old == shamcomm::world_rank()) {
                auto &patchdata = owned_data.get(op.patch_id);
                patchdata_isend(patchdata, rq_lst, op.rank_owner_new, op.tag_comm, MPI_COMM_WORLD);
            }
        }

        // receive
        for (const ChangeOp op : change_list.change_ops) {
            auto &id_patch = op.patch_id;

            // this rank is the receiver
            if (op.rank_owner_new == shamcomm::world_rank()) {
                // insert an empty patchdata to receive into (reconstructed: the
                // original built it from the layout held by pdl_ptr)
                owned_data.add_obj(id_patch, shamrock::patch::PatchDataLayer(pdl_ptr));
                patchdata_irecv_probe(
                    owned_data.get(id_patch),
                    rq_lst,
                    op.rank_owner_old,
                    op.tag_comm,
                    MPI_COMM_WORLD);
            }
        }

        waitall_pdat_mpi_rq(rq_lst);

        // update ownership and erase the patchdata this rank handed off
        for (const ChangeOp op : change_list.change_ops) {
            auto &id_patch = op.patch_id;

            patch_list.global[op.patch_idx].node_owner_id = op.rank_owner_new;

            // this rank was the sender: drop the local copy
            if (op.rank_owner_old == shamcomm::world_rank()) {
                owned_data.erase(id_patch);
            }
        }
    }
#endif

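    /// Split original_pd into the 8 child patchdata pdats, using the domain
    /// bounds of the 8 child patches (computed from sim_box) to decide where
    /// each object goes.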
    template<class Vectype>
    void split_patchdata(
        shamrock::patch::PatchDataLayer &original_pd,      // reconstructed parameter
        const shamrock::patch::SimulationBoxInfo &sim_box, // reconstructed parameter
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats) {

        using ptype = typename shambase::VectorProperties<Vectype>::component_type;

        // domain bounds of the 8 child patches
        auto [bmin_p0, bmax_p0] = sim_box.patch_coord_to_domain<Vectype>(patches[0]);
        auto [bmin_p1, bmax_p1] = sim_box.patch_coord_to_domain<Vectype>(patches[1]);
        auto [bmin_p2, bmax_p2] = sim_box.patch_coord_to_domain<Vectype>(patches[2]);
        auto [bmin_p3, bmax_p3] = sim_box.patch_coord_to_domain<Vectype>(patches[3]);
        auto [bmin_p4, bmax_p4] = sim_box.patch_coord_to_domain<Vectype>(patches[4]);
        auto [bmin_p5, bmax_p5] = sim_box.patch_coord_to_domain<Vectype>(patches[5]);
        auto [bmin_p6, bmax_p6] = sim_box.patch_coord_to_domain<Vectype>(patches[6]);
        auto [bmin_p7, bmax_p7] = sim_box.patch_coord_to_domain<Vectype>(patches[7]);

        original_pd.split_patchdata<Vectype>(
            pdats,
            {bmin_p0, bmin_p1, bmin_p2, bmin_p3, bmin_p4, bmin_p5, bmin_p6, bmin_p7},
            {bmax_p0, bmax_p1, bmax_p2, bmax_p3, bmax_p4, bmax_p5, bmax_p6, bmax_p7});
    }

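    // Explicit instantiations for every supported main-field vector type. The
    // two leading parameters are reconstructed to match the template definition
    // above.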
    template void split_patchdata<f32_3>(
        shamrock::patch::PatchDataLayer &original_pd,
        const shamrock::patch::SimulationBoxInfo &sim_box,
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats);

    template void split_patchdata<f64_3>(
        shamrock::patch::PatchDataLayer &original_pd,
        const shamrock::patch::SimulationBoxInfo &sim_box,
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats);

    template void split_patchdata<u32_3>(
        shamrock::patch::PatchDataLayer &original_pd,
        const shamrock::patch::SimulationBoxInfo &sim_box,
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats);

    template void split_patchdata<u64_3>(
        shamrock::patch::PatchDataLayer &original_pd,
        const shamrock::patch::SimulationBoxInfo &sim_box,
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats);

    template void split_patchdata<i64_3>(
        shamrock::patch::PatchDataLayer &original_pd,
        const shamrock::patch::SimulationBoxInfo &sim_box,
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats);

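    /// Split the patchdata stored at key_original into 8 children, one per
    /// entry of patches, dispatching on the vector type of the main field. The
    /// parent patchdata is erased and the 8 children are inserted in its place.
    /// Does nothing if the key is not owned by this rank.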
    void SchedulerPatchData::split_patchdata(
        u64 key_original, const std::array<shamrock::patch::Patch, 8> patches) {

        auto search = owned_data.find(key_original);

        if (search != owned_data.not_found()) {

            shamrock::patch::PatchDataLayer &original_pd = search->second;

            // the 8 child patchdata, built on the same layout as the parent
            // (reconstructed: these declarations were dropped in extraction)
            shamrock::patch::PatchDataLayer pd0(pdl_ptr);
            shamrock::patch::PatchDataLayer pd1(pdl_ptr);
            shamrock::patch::PatchDataLayer pd2(pdl_ptr);
            shamrock::patch::PatchDataLayer pd3(pdl_ptr);
            shamrock::patch::PatchDataLayer pd4(pdl_ptr);
            shamrock::patch::PatchDataLayer pd5(pdl_ptr);
            shamrock::patch::PatchDataLayer pd6(pdl_ptr);
            shamrock::patch::PatchDataLayer pd7(pdl_ptr);

            auto &pdl = shambase::get_check_ref(pdl_ptr);

            // dispatch on the vector type of the main field (id = 0)
            if (pdl.check_main_field_type<f32_3>()) {
                shamrock::scheduler::split_patchdata<f32_3>(
                    original_pd, sim_box, patches, {pd0, pd1, pd2, pd3, pd4, pd5, pd6, pd7});
            } else if (pdl.check_main_field_type<f64_3>()) {
                shamrock::scheduler::split_patchdata<f64_3>(
                    original_pd, sim_box, patches, {pd0, pd1, pd2, pd3, pd4, pd5, pd6, pd7});
            } else if (pdl.check_main_field_type<u32_3>()) {
                shamrock::scheduler::split_patchdata<u32_3>(
                    original_pd, sim_box, patches, {pd0, pd1, pd2, pd3, pd4, pd5, pd6, pd7});
            } else if (pdl.check_main_field_type<u64_3>()) {
                shamrock::scheduler::split_patchdata<u64_3>(
                    original_pd, sim_box, patches, {pd0, pd1, pd2, pd3, pd4, pd5, pd6, pd7});
            } else if (pdl.check_main_field_type<i64_3>()) {
                shamrock::scheduler::split_patchdata<i64_3>(
                    original_pd, sim_box, patches, {pd0, pd1, pd2, pd3, pd4, pd5, pd6, pd7});
            } else {
                shambase::throw_with_loc<std::runtime_error>(
                    "the main field does not match any supported vector type");
            }

            owned_data.erase(key_original);

            owned_data.add_obj(patches[0].id_patch, std::move(pd0));
            owned_data.add_obj(patches[1].id_patch, std::move(pd1));
            owned_data.add_obj(patches[2].id_patch, std::move(pd2));
            owned_data.add_obj(patches[3].id_patch, std::move(pd3));
            owned_data.add_obj(patches[4].id_patch, std::move(pd4));
            owned_data.add_obj(patches[5].id_patch, std::move(pd5));
            owned_data.add_obj(patches[6].id_patch, std::move(pd6));
            owned_data.add_obj(patches[7].id_patch, std::move(pd7));
        }
    }

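    /// Merge the 8 patchdata at old_keys into a single patchdata stored at
    /// new_key. Throws if any of the 8 keys is not owned by this rank.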
    void SchedulerPatchData::merge_patchdata(u64 new_key, const std::array<u64, 8> old_keys) {

        shamrock::patch::PatchDataLayer new_pdat(pdl_ptr);

        // check and merge the 8 children (unrolled by hand in the original; a
        // loop is behavior-equivalent since owned_data is untouched until all
        // 8 keys have been validated)
        for (u32 i = 0; i < 8; i++) {
            auto search = owned_data.find(old_keys[i]);
            if (search == owned_data.not_found()) {
                shambase::throw_with_loc<std::runtime_error>(shambase::format(
                    "patchdata for key={} was not owned by the node", old_keys[i]));
            }
            new_pdat.insert_elements(search->second);
        }

        for (u32 i = 0; i < 8; i++) {
            owned_data.erase(old_keys[i]);
        }

        owned_data.add_obj(new_key, std::move(new_pdat));
    }
} // namespace shamrock::scheduler