#define NEW_LB_APPLY_IMPL

#ifndef NEW_LB_APPLY_IMPL
    // ... (legacy-path includes, elided)
#endif

namespace shamrock::scheduler {

#ifdef NEW_LB_APPLY_IMPL
    struct Message {
        std::unique_ptr<shamcomm::CommunicationBuffer> buf;
        // ... (remaining fields, presumably the peer rank and the communication tag, elided)
    };
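    // The two helpers below implement the point-to-point leg of the load balancing
    // exchange: send_messages() posts one non-blocking send per Message, while
    // recv_probe_messages() probes each expected incoming message, allocates a
    // CommunicationBuffer of the probed size, and posts the matching non-blocking
    // receive. Routing information (peer rank, tag) is carried by the elided
    // Message fields.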
    void send_messages(std::vector<Message> &msgs, std::vector<MPI_Request> &rqs) {

        for (auto &msg : msgs) {
            rqs.push_back(MPI_Request{});
            u32 rq_index = rqs.size() - 1;
            auto &rq     = rqs[rq_index];

            u64 bsize = msg.buf->get_size();
            if (bsize % 8 != 0) { // assumed guard; the original check is elided
                shambase::throw_with_loc(
                    "the following MPI comm assumes that we can send longs to pack 8 bytes");
            }
            u64 lcount = bsize / 8;

            // ... (non-blocking send of lcount MPI_LONG words, elided)
        }
    }
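    // The byte payload is shipped as lcount words of MPI_LONG (8 bytes each), which
    // is why the buffer size must be a multiple of 8. A minimal sketch of the elided
    // send, assuming hypothetical Message fields `rank_dest` and `tag` and the
    // shamcomm::mpi::Isend wrapper referenced by this file:
    //
    //     shamcomm::mpi::Isend(
    //         /* buf      */ <raw pointer to msg.buf's storage>,
    //         /* count    */ int(lcount),
    //         /* datatype */ MPI_LONG,
    //         /* dest     */ msg.rank_dest,
    //         /* tag      */ msg.tag,
    //         /* comm     */ MPI_COMM_WORLD,
    //         /* request  */ &rq);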
    void recv_probe_messages(std::vector<Message> &msgs, std::vector<MPI_Request> &rqs) {

        for (auto &msg : msgs) {
            rqs.push_back(MPI_Request{});
            u32 rq_index = rqs.size() - 1;
            auto &rq     = rqs[rq_index];

            // ... (probe the matching incoming message and query its size, elided)

            msg.buf = std::make_unique<shamcomm::CommunicationBuffer>(
                // ... (probed byte count, elided)
                shamsys::instance::get_compute_scheduler_ptr());

            // ... (non-blocking receive into msg.buf, elided)
        }
    }
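    // A sketch of the elided probe/receive sequence, assuming it uses the
    // shamcomm::mpi wrappers referenced by this file and hypothetical Message
    // fields `rank_src` and `tag`:
    //
    //     MPI_Status st;
    //     shamcomm::mpi::Probe(msg.rank_src, msg.tag, MPI_COMM_WORLD, &st);
    //     int lcount;
    //     shamcomm::mpi::Get_count(&st, MPI_LONG, &lcount);
    //     // allocate a buffer of 8 * lcount bytes, then:
    //     shamcomm::mpi::Irecv(
    //         <raw pointer to msg.buf's storage>, lcount, MPI_LONG,
    //         msg.rank_src, msg.tag, MPI_COMM_WORLD, &rq);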
    /// apply a load balancing change list to shuffle patchdata around the cluster
    void SchedulerPatchData::apply_change_list(
        const shamrock::scheduler::LoadBalancingChangeList &change_list,
        SchedulerPatchList &patch_list) {

        // Serialization side: pack one patchdata into a byte buffer
        // (enclosing helper and construction of the serializer `ser` elided).
            ser.allocate(pdat.serialize_buf_byte_size());
            pdat.serialize_buf(ser);
            return ser.finalize();

        // Deserialization side: the `deserializer` helper rebuilds a PatchDataLayer
        // from a received buffer (enclosing declaration elided).
            // ... (serializer `ser` reconstructed from the received buffer and
            shamsys::instance::get_compute_scheduler_ptr(),
            // ... elided)
            return shamrock::patch::PatchDataLayer::deserialize_buf(ser, pdl_ptr);
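        // The exchange round-trips each patchdata through a flat byte buffer:
        // serialize_buf_byte_size() sizes the allocation, serialize_buf() fills it,
        // finalize() yields the buffer that gets wrapped into a CommunicationBuffer
        // below, and PatchDataLayer::deserialize_buf() undoes the packing on the
        // receiving rank (pdl_ptr presumably being the shared layout description).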
        std::vector<Message> send_payloads;

        for (const ChangeOp op : change_list.change_ops) {
            // ... (guard: only ops whose patch this rank currently owns are sent, elided)
            auto &patchdata = owned_data.get(op.patch_id);
            // ... (serialize patchdata into a device buffer `tmp`, elided)
            send_payloads.push_back(
                // ... (Message aggregate, elided:)
                std::make_unique<shamcomm::CommunicationBuffer>(
                    std::move(tmp), shamsys::instance::get_compute_scheduler_ptr()),
                // ... (destination rank and tag, elided)
                );
        }

        std::vector<MPI_Request> rqs;
        send_messages(send_payloads, rqs);
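        // Every locally owned patch whose owner changes has now been serialized and
        // queued; send_messages() turns each queued payload into a non-blocking MPI
        // send addressed, via the elided Message fields, to op.rank_owner_new.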
        std::vector<Message> recv_payloads;

        for (const ChangeOp op : change_list.change_ops) {
            auto &id_patch = op.patch_id;
            // ... (guard: only ops whose new owner is this rank are received, elided)
            recv_payloads.push_back(
                // ... (Message with an empty buffer, to be allocated by the probe:)
                std::unique_ptr<shamcomm::CommunicationBuffer>{},
                // ... (source rank and tag, elided)
                );
        }

        recv_probe_messages(recv_payloads, rqs);

        std::vector<MPI_Status> st_lst(rqs.size());
        shamcomm::mpi::Waitall(rqs.size(), rqs.data(), st_lst.data());
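        // Waitall() completes both the sends posted by send_messages() and the
        // receives posted by recv_probe_messages(); once it returns, every received
        // Message holds a fully populated byte buffer.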
        for (const ChangeOp op : change_list.change_ops) {
            auto &id_patch = op.patch_id;
            // ... (guard and lookup of the payload index `idx` for this op, elided)
            Message &msg = recv_payloads[idx];
            // ... (recover the raw byte buffer `buf` from msg.buf, elided)
            owned_data.add_obj(id_patch, deserializer(std::move(buf)));
        }
        for (const ChangeOp op : change_list.change_ops) {
            auto &id_patch = op.patch_id;
            // ... (presumably: drop the local copy if this rank was the previous owner, elided)
            patch_list.global[op.patch_idx].node_owner_id = op.rank_owner_new;
        }
    }
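    // Every rank replays the complete change list on patch_list.global, so the
    // patch -> owner map stays consistent across the whole cluster, including for
    // ops in which this rank is neither the sender nor the receiver.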
    // Legacy exchange path (the alternative apply_change_list body, compiled when
    // NEW_LB_APPLY_IMPL is not defined), built on the patchdata_isend /
    // patchdata_irecv_probe helpers:

        std::vector<PatchDataMpiRequest> rq_lst;

        // ...
        for (const ChangeOp op : change_list.change_ops) {
            // ... (only ops whose patch this rank currently owns are sent, elided)
            auto &patchdata = owned_data.get(op.patch_id);
            patchdata_isend(patchdata, rq_lst, op.rank_owner_new, op.tag_comm, MPI_COMM_WORLD);
        }

        // ...
        for (const ChangeOp op : change_list.change_ops) {
            auto &id_patch = op.patch_id;
            // ...
            patchdata_irecv_probe(
                // ... (arguments elided)
                );
        }

        waitall_pdat_mpi_rq(rq_lst);

        // ...
        for (const ChangeOp op : change_list.change_ops) {
            auto &id_patch = op.patch_id;
            // ...
            patch_list.global[op.patch_idx].node_owner_id = op.rank_owner_new;
        }
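    // The legacy path exchanges whole patchdata objects through the patchdata_isend
    // / patchdata_irecv_probe helpers and a PatchDataMpiRequest list instead of
    // packing them into CommunicationBuffer messages, but it ends with the same
    // global ownership update as the new path.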
    template<class Vectype>
    void split_patchdata(
        // ... (the original patchdata and the simulation box info, elided)
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats) {

        using ptype = typename shambase::VectorProperties<Vectype>::component_type;

        // ... (convert the 8 patch integer coordinates into domain bounds
        //      bmin_p0..bmin_p7 / bmax_p0..bmax_p7, elided)

        original_pd.split_patchdata<Vectype>(
            // ...
            {bmin_p0, bmin_p1, bmin_p2, bmin_p3, bmin_p4, bmin_p5, bmin_p6, bmin_p7},
            {bmax_p0, bmax_p1, bmax_p2, bmax_p3, bmax_p4, bmax_p5, bmax_p6, bmax_p7});
    }
    template void split_patchdata<f32_3>(
        // ...
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats);

    template void split_patchdata<f64_3>(
        // ...
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats);

    template void split_patchdata<u32_3>(
        // ...
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats);

    template void split_patchdata<u64_3>(
        // ...
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats);

    template void split_patchdata<i64_3>(
        // ...
        const std::array<shamrock::patch::Patch, 8> patches,
        std::array<std::reference_wrapper<shamrock::patch::PatchDataLayer>, 8> pdats);
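    // The free function is explicitly instantiated for every position type the
    // scheduler supports (f32_3, f64_3, u32_3, u64_3, i64_3); the member function
    // below picks the matching instantiation at runtime from the type of the main
    // field of the patchdata layout.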
    void SchedulerPatchData::split_patchdata(
        u64 key_original, const std::array<shamrock::patch::Patch, 8> patches) {

        // ... (fetch the patchdata to split as `original_pd` and create the 8 child
        //      patchdatas pd0..pd7, elided)

        // Dispatch on the type of the main field (conditions elided):
        if (/* main field is f32_3 */) {
            shamrock::scheduler::split_patchdata<f32_3>(
                original_pd, sim_box, patches, {pd0, pd1, pd2, pd3, pd4, pd5, pd6, pd7});
        } else if (/* main field is f64_3 */) {
            shamrock::scheduler::split_patchdata<f64_3>(
                original_pd, sim_box, patches, {pd0, pd1, pd2, pd3, pd4, pd5, pd6, pd7});
        } else if (/* main field is u32_3 */) {
            shamrock::scheduler::split_patchdata<u32_3>(
                original_pd, sim_box, patches, {pd0, pd1, pd2, pd3, pd4, pd5, pd6, pd7});
        } else if (/* main field is u64_3 */) {
            shamrock::scheduler::split_patchdata<u64_3>(
                original_pd, sim_box, patches, {pd0, pd1, pd2, pd3, pd4, pd5, pd6, pd7});
        } else if (/* main field is i64_3 */) {
            shamrock::scheduler::split_patchdata<i64_3>(
                original_pd, sim_box, patches, {pd0, pd1, pd2, pd3, pd4, pd5, pd6, pd7});
        } else {
            // ... (error path, elided, with the message:)
            //     "the main field does not match any"
        }
        owned_data.add_obj(patches[0].id_patch, std::move(pd0));
        owned_data.add_obj(patches[1].id_patch, std::move(pd1));
        owned_data.add_obj(patches[2].id_patch, std::move(pd2));
        owned_data.add_obj(patches[3].id_patch, std::move(pd3));
        owned_data.add_obj(patches[4].id_patch, std::move(pd4));
        owned_data.add_obj(patches[5].id_patch, std::move(pd5));
        owned_data.add_obj(patches[6].id_patch, std::move(pd6));
        owned_data.add_obj(patches[7].id_patch, std::move(pd7));
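        // Each child patchdata is now registered under its own patch id; the parent
        // entry keyed by key_original is presumably dropped on the elided lines that
        // follow.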
373 "patchdata for key=%d was not owned by the node", old_keys[0]));
377 "patchdata for key=%d was not owned by the node", old_keys[1]));
381 "patchdata for key=%d was not owned by the node", old_keys[2]));
385 "patchdata for key=%d was not owned by the node", old_keys[3]));
389 "patchdata for key=%d was not owned by the node", old_keys[4]));
393 "patchdata for key=%d was not owned by the node", old_keys[5]));
397 "patchdata for key=%d was not owned by the node", old_keys[6]));
401 "patchdata for key=%d was not owned by the node", old_keys[7]));
        // ... (the merged patchdata `new_pdat` is created on elided lines)

        new_pdat.insert_elements(search0->second);
        new_pdat.insert_elements(search1->second);
        new_pdat.insert_elements(search2->second);
        new_pdat.insert_elements(search3->second);
        new_pdat.insert_elements(search4->second);
        new_pdat.insert_elements(search5->second);
        new_pdat.insert_elements(search6->second);
        new_pdat.insert_elements(search7->second);
        // ... (the 8 old entries are presumably removed from owned_data, elided)

        owned_data.add_obj(new_key, std::move(new_pdat));
    }
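    // Together, split_patchdata() and merge_patchdata() are the two structural
    // moves on owned_data: splitting one patchdata into the 8 children of a refined
    // patch, and gathering 8 child patchdatas back into a single entry keyed by
    // new_key when patches are merged.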