Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
PatchScheduler.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
19#include "shambase/string.hpp"
20#include "shambase/time.hpp"
21#include "shambackends/math.hpp"
31#include <ctime>
32#include <memory>
33#include <optional>
34#include <sstream>
35#include <stdexcept>
36#include <vector>
37
// TODO move types init out
/// Initialize the MPI datatypes the scheduler needs.
/// Currently a no-op: the patch MPI type creation below is commented out —
/// presumably handled elsewhere now; verify before re-enabling.
void PatchScheduler::init_mpi_required_types() {

    // if(!patch::is_mpi_patch_type_active()){
    // patch::create_MPI_patch_type();
    // }
}
45
/// Free the MPI datatypes created by init_mpi_required_types().
/// Currently a no-op, mirroring the commented-out creation path.
void PatchScheduler::free_mpi_required_types() {

    // if(patch::is_mpi_patch_type_active()){
    // patch::free_MPI_patch_type();
    // }
}
52
53template<u32 dim>
54void PatchScheduler::make_patch_base_grid(std::array<u32, dim> patch_count) {
55
56 static_assert(dim == 3, "this is not implemented for dim != 3");
57
58 u32 max_lin_patch_count = 0;
59 for (u32 i = 0; i < dim; i++) {
60 max_lin_patch_count = sycl::max(max_lin_patch_count, patch_count[i]);
61 }
62
63 u64 coord_div_fact = sham::roundup_pow2_clz(max_lin_patch_count);
64
65 u64 sz_root_patch = PatchScheduler::max_axis_patch_coord_length / coord_div_fact;
66
67 std::vector<shamrock::patch::PatchCoord<3>> coords;
68 for (u32 x = 0; x < patch_count[0]; x++) {
69 for (u32 y = 0; y < patch_count[1]; y++) {
70 for (u32 z = 0; z < patch_count[2]; z++) {
72
73 coord.coord_min[0] = sz_root_patch * (x);
74 coord.coord_min[1] = sz_root_patch * (y);
75 coord.coord_min[2] = sz_root_patch * (z);
76 coord.coord_max[0] = sz_root_patch * (x + 1) - 1;
77 coord.coord_max[1] = sz_root_patch * (y + 1) - 1;
78 coord.coord_max[2] = sz_root_patch * (z + 1) - 1;
79
80 coords.push_back(coord);
81 }
82 }
83 }
84
86 bounds.coord_min[0] = 0;
87 bounds.coord_min[1] = 0;
88 bounds.coord_min[2] = 0;
89 bounds.coord_max[0] = sz_root_patch * patch_count[0] - 1;
90 bounds.coord_max[1] = sz_root_patch * patch_count[1] - 1;
91 bounds.coord_max[2] = sz_root_patch * patch_count[2] - 1;
92
93 get_sim_box().set_patch_coord_bounding_box(bounds);
94
95 add_root_patches(coords);
96}
97
98template void PatchScheduler::make_patch_base_grid<3>(std::array<u32, 3> patch_count);
99
101 std::vector<shamrock::patch::PatchCoord<3>> coords) {
102
103 using namespace shamrock::patch;
104
105 std::vector<u64> ret;
106
107 for (auto coord : coords) {
108
109 u32 node_owner_id = 0;
110
111 Patch root;
114 root.load_value = 0;
115 root.coord_min[0] = coord.coord_min[0];
116 root.coord_min[1] = coord.coord_min[1];
117 root.coord_min[2] = coord.coord_min[2];
118 root.coord_max[0] = coord.coord_max[0];
119 root.coord_max[1] = coord.coord_max[1];
120 root.coord_max[2] = coord.coord_max[2];
121 root.node_owner_id = node_owner_id;
122
123 patch_list.global.push_back(root);
125
126 if (shamcomm::world_rank() == node_owner_id) {
127 patch_data.owned_data.add_obj(root.id_patch, PatchDataLayer(get_layout_ptr_old()));
128 shamlog_debug_sycl_ln("Scheduler", "adding patch data");
129 } else {
130 shamlog_debug_sycl_ln(
131 "Scheduler",
132 "patch data wasn't added rank =",
134 " ower =",
135 node_owner_id);
136 }
137
138 patch_tree.insert_root_node(root.id_patch, coord);
139
140 ret.push_back(root.id_patch);
141
142 // auto [bmin,bmax] = get_sim_box().patch_coord_to_domain<u64_3>(root);
143 //
144 //
145 // shamlog_debug_ln("Scheduler", "adding patch : [ (",
146 // coord.x_min,
147 // coord.y_min,
148 // coord.z_min,") ] [ (",
149 // coord.x_max,
150 // coord.y_max,
151 // coord.z_max,") ]", bmin,bmax
152 //);
153 }
154
155 // build_local() is declared as nodiscard
156 (void) patch_list.build_local();
160
162
163 return ret;
164}
165
// NOTE(review): the signature line was dropped by the extraction; per the
// header this is `void PatchScheduler::allpush_data(shamrock::patch::PatchDataLayer &pdat)`
// — pushes the (node-wise identical) content of pdat into the scheduler.

    shamlog_debug_ln("Scheduler", "pushing data obj cnt =", pdat.get_obj_cnt());

    // For each locally owned patch, insert the subset of pdat that falls
    // inside that patch's coordinate range.
    for_each_patch_data([&](u64 id_patch,
        // NOTE(review): the remaining lambda parameter lines (binding the
        // patch `cur_p` and the target layer `pdat_sched` used below) were
        // dropped by the extraction.
        auto variant_main = pdl_old().get_main_field_any();

        variant_main.visit([&](auto &arg) {
            using base_t = typename std::remove_reference<decltype(arg)>::type::field_T;

            // NOTE(review): a line is missing here — presumably the
            // `if constexpr` guard matching the `else` branch below.
            auto [bmin, bmax] = get_sim_box().patch_coord_to_domain<base_t>(cur_p);

            shamlog_debug_sycl_ln(
                "Scheduler", "pushing data in patch ", id_patch, "search range :", bmin, bmax);

            pdat_sched.insert_elements_in_range(pdat, bmin, bmax);
        } else {
            throw std::runtime_error("this does not yet work with dimension different from 3");
        }
        });
    });
}
191
// NOTE(review): the signature line was dropped by the extraction; per the
// header this is `void PatchScheduler::add_root_patch()` — registers a single
// root patch spanning the whole integer patch coordinate space.
    using namespace shamrock::patch;

    PatchCoord coord;
    coord.coord_min[0] = 0;
    coord.coord_min[1] = 0;
    coord.coord_min[2] = 0;
    coord.coord_max[0] = max_axis_patch_coord;
    coord.coord_max[1] = max_axis_patch_coord;
    coord.coord_max[2] = max_axis_patch_coord;

    add_root_patches({coord});

    // NOTE(review): a trailing line (original 205) is not visible here.
}
207
208PatchScheduler::PatchScheduler(
209 const std::shared_ptr<shamrock::patch::PatchDataLayerLayout> &pdl_ptr,
210 u64 crit_split,
211 u64 crit_merge)
212 : pdl_ptr(pdl_ptr),
213 patch_data(
214 pdl_ptr,
215 {{0, 0, 0}, {max_axis_patch_coord, max_axis_patch_coord, max_axis_patch_coord}}) {
216
217 crit_patch_split = crit_split;
218 crit_patch_merge = crit_merge;
219}
220
221PatchScheduler::~PatchScheduler() {}
222
223bool PatchScheduler::should_resize_box(bool node_in) {
224 u16 tmp = node_in;
225 u16 out = 0;
226 shamcomm::mpi::Allreduce(&tmp, &out, 1, mpi_type_u16, MPI_MAX, MPI_COMM_WORLD);
227 return out;
228}
229
// TODO move Loadbalancing function to template state
/// Synchronize patch metadata across ranks and optionally rebalance the load.
/// @param global_patch_sync if true, resync the global patch list first
/// @param balance_load if true, compute and apply a load-balancing change list
/// NOTE(review): the actual calls in this function (original lines 233-248)
/// were dropped by the extraction — only the control skeleton remains; the
/// commented-out legacy implementation further down shows the intended steps.
void PatchScheduler::sync_build_LB(bool global_patch_sync, bool balance_load) {

    if (global_patch_sync)

    if (balance_load) {
        // real load balancing

        // exchange data
    }

    // rebuild local table
}
250
/// Return the transform mapping integer patch coordinates to the f32_3
/// physical domain: {translate_factor, scale_factor} with
/// pos = coord * scale + translate.
/// Throws if the layout's main field is not f32_3.
template<>
std::tuple<f32_3, f32_3> PatchScheduler::get_box_tranform() {
    if (!pdl_old().check_main_field_type<f32_3>())
        // NOTE(review): the throw expression line (original 254) was dropped
        // by the extraction; only its message argument remains below.
        "cannot query single precision box the main field is not of f32_3 type");

    auto [bmin, bmax] = patch_data.sim_box.get_bounding_box<f32_3>();

    f32_3 translate_factor = bmin;
    f32_3 scale_factor = (bmax - bmin) / LoadBalancer::max_box_sz;

    return {translate_factor, scale_factor};
}
264
/// Return the transform mapping integer patch coordinates to the f64_3
/// physical domain (see the f32_3 specialization above).
/// Throws if the layout's main field is not f64_3.
template<>
std::tuple<f64_3, f64_3> PatchScheduler::get_box_tranform() {
    if (!pdl_old().check_main_field_type<f64_3>())
        // NOTE(review): throw expression line dropped by the extraction.
        // Also the message says "single precision" for the f64_3 (double
        // precision) case — copy-paste from the f32_3 overload; fix upstream.
        "cannot query single precision box the main field is not of f64_3 type");

    auto [bmin, bmax] = patch_data.sim_box.get_bounding_box<f64_3>();

    f64_3 translate_factor = bmin;
    f64_3 scale_factor = (bmax - bmin) / LoadBalancer::max_box_sz;

    return {translate_factor, scale_factor};
}
278
/// Return the f32_3 bounding box of the simulation domain.
/// Throws if the layout's main field is not f32_3.
template<>
std::tuple<f32_3, f32_3> PatchScheduler::get_box_volume() {
    if (!pdl_old().check_main_field_type<f32_3>())
        // NOTE(review): throw expression line (original 282) dropped by the
        // extraction; only its message remains.
        "cannot query single precision box the main field is not of f32_3 type");

    return patch_data.sim_box.get_bounding_box<f32_3>();
}
287
/// Return the f64_3 bounding box of the simulation domain.
/// Throws if the layout's main field is not f64_3.
template<>
std::tuple<f64_3, f64_3> PatchScheduler::get_box_volume() {
    if (!pdl_old().check_main_field_type<f64_3>())
        // NOTE(review): throw line dropped by the extraction. Message says
        // "single precision" for the double-precision case — copy-paste bug.
        "cannot query single precision box the main field is not of f64_3 type");

    return patch_data.sim_box.get_bounding_box<f64_3>();
}
296
/// Return the i64_3 bounding box of the simulation domain.
/// Throws if the layout's main field is not i64_3.
template<>
std::tuple<i64_3, i64_3> PatchScheduler::get_box_volume() {
    if (!pdl_old().check_main_field_type<i64_3>())
        // NOTE(review): throw line dropped by the extraction. Message says
        // "single precision" for the integer case — copy-paste bug.
        "cannot query single precision box the main field is not of i64_3 type");

    return patch_data.sim_box.get_bounding_box<i64_3>();
}
305
// TODO clean the output of this function
/// Run one scheduler step: synchronize patch metadata, optionally split/merge
/// patches according to the load thresholds, and optionally rebalance the
/// patches across ranks. Per-phase timings are accumulated and printed on
/// rank 0 at the end.
/// NOTE(review): several call lines (the metadata sync, the split/merge
/// request generation, the LB change-list computation/application, ...) were
/// dropped by the extraction — only the timing/control skeleton is visible.
/// @param do_split_merge enable the split/merge phase
/// @param do_load_balancing enable the load-balancing phase
void PatchScheduler::scheduler_step(bool do_split_merge, bool do_load_balancing) {
    StackEntry stack_loc{};

    // std::cout << dump_status();

    if (!is_mpi_sycl_interop_active())
        // NOTE(review): throw expression line (original 313) dropped by the
        // extraction; only its message remains.
        "sycl mpi interop not initialized");

    shambase::Timer timer;
    shamlog_debug_ln("Scheduler", "running scheduler step");

    // Per-phase timers; optional entries are only set when the matching
    // phase actually runs, and print_stats() skips unset ones.
    struct SchedulerStepTimers {
        shambase::Timer global_timer;
        shambase::Timer metadata_sync;
        std::optional<shambase::Timer> global_idx_map_build = {};
        std::optional<shambase::Timer> patch_tree_count_reduce = {};
        std::optional<shambase::Timer> gen_merge_split_rq = {};
        std::optional<u32_2> split_merge_cnt = {};
        std::optional<shambase::Timer> apply_splits = {};
        std::optional<shambase::Timer> load_balance_compute = {};
        std::optional<u32> load_balance_move_op_cnt = {};
        std::optional<shambase::Timer> load_balance_apply = {};

        // Print a per-phase timing breakdown (rank 0 only), each phase as a
        // percentage of the whole step.
        void print_stats() {
            if (shamcomm::world_rank() == 0) {
                f64 total = global_timer.nanosec;
                std::string str = "";
                str += "Scheduler step timings : ";
                str += shambase::format(
                    "\n metadata sync : {:<10} ({:2.1f}%)",
                    metadata_sync.get_time_str(),
                    f64(100 * (metadata_sync.nanosec / total)));
                if (patch_tree_count_reduce) {
                    str += shambase::format(
                        "\n patch tree reduce : {:<10} ({:2.1f}%)",
                        patch_tree_count_reduce->get_time_str(),
                        100 * (patch_tree_count_reduce->nanosec / total));
                }
                if (gen_merge_split_rq) {
                    str += shambase::format(
                        "\n gen split merge : {:<10} ({:2.1f}%)",
                        gen_merge_split_rq->get_time_str(),
                        100 * (gen_merge_split_rq->nanosec / total));
                }
                if (split_merge_cnt) {
                    str += shambase::format(
                        "\n split / merge op : {}/{}",
                        split_merge_cnt->x(),
                        split_merge_cnt->y());
                }
                if (apply_splits) {
                    str += shambase::format(
                        "\n apply split merge : {:<10} ({:2.1f}%)",
                        apply_splits->get_time_str(),
                        100 * (apply_splits->nanosec / total));
                }
                if (load_balance_compute) {
                    str += shambase::format(
                        "\n LB compute : {:<10} ({:2.1f}%)",
                        load_balance_compute->get_time_str(),
                        100 * (load_balance_compute->nanosec / total));
                }
                if (load_balance_move_op_cnt) {
                    str += shambase::format(
                        "\n LB move op cnt : {}", *load_balance_move_op_cnt);
                }
                if (load_balance_apply) {
                    str += shambase::format(
                        "\n LB apply : {:<10} ({:2.1f}%)",
                        load_balance_apply->get_time_str(),
                        100 * (load_balance_apply->nanosec / total));
                }
                logger::info_ln("Scheduler", str);
            }
        }
    } timers;

    timers.global_timer.start();

    // NOTE(review): the metadata sync call (original line 390, presumably the
    // global patch-list sync) was dropped by the extraction.
    timers.metadata_sync.start();
    timers.metadata_sync.end();

    // std::cout << dump_status();

    std::unordered_set<u64> split_rq;
    std::unordered_set<u64> merge_rq;

    if (do_split_merge) {
        // std::cout << dump_status() << std::endl;

        // std::cout << "build_global_idx_map" <<std::endl;
        timers.global_idx_map_build = shambase::Timer{};
        timers.global_idx_map_build->start(); // TODO check if it it used outside of split merge ->
                                              // maybe need to be put before the if
        // NOTE(review): the timed call (original line 405) was dropped by the
        // extraction.
        timers.global_idx_map_build->end();

        // std::cout << dump_status() << std::endl;

        // std::cout << "tree partial_values_reduction" <<std::endl;
        timers.patch_tree_count_reduce = shambase::Timer{};
        timers.patch_tree_count_reduce->start();
        // NOTE(review): the timed call (original line 413) was dropped by the
        // extraction.
        timers.patch_tree_count_reduce->end();

        // std::cout << dump_status() << std::endl;

        // Generate merge and split request
        timers.gen_merge_split_rq = shambase::Timer{};
        timers.gen_merge_split_rq->start();
        // NOTE(review): the calls filling split_rq/merge_rq (original lines
        // 421-422) were dropped by the extraction.
        timers.gen_merge_split_rq->end();

        timers.split_merge_cnt = u32_2{split_rq.size(), merge_rq.size()};
        /*
        std::cout << " |-> split rq : ";
        for(u64 i : split_rq){
            std::cout << i << " ";
        }std::cout << std::endl;
        //*/

        /*
        std::cout << " |-> merge rq : ";
        for(u64 i : merge_rq){
            std::cout << i << " ";
        }std::cout << std::endl;
        //*/

        // std::cout << dump_status() << std::endl;

        // std::cout << "split_patches" <<std::endl;
        timers.apply_splits = shambase::Timer{};
        timers.apply_splits->start();
        split_patches(split_rq);
        timers.apply_splits->end();

        // std::cout << dump_status() << std::endl;

        // check not necessary if no splits
        // NOTE(review): a line (original 451) was dropped by the extraction.

        set_patch_pack_values(merge_rq);
    }

    if (do_load_balancing) {
        StackEntry stack_loc{};
        timers.load_balance_compute = shambase::Timer{};
        timers.load_balance_compute->start();
        // generate LB change list
        // NOTE(review): the change_list computation (original lines 461-462)
        // was dropped by the extraction; change_list is read just below.
        timers.load_balance_compute->end();

        timers.load_balance_move_op_cnt = change_list.change_ops.size();

        timers.load_balance_apply = shambase::Timer{};
        timers.load_balance_apply->start();
        // apply LB change list
        // NOTE(review): the apply call (original line 470) was dropped by the
        // extraction.
        timers.load_balance_apply->end();
    }

    // std::cout << dump_status();

    if (do_split_merge) {
        // NOTE(review): a line (original 477) was dropped by the extraction.
        merge_patches(merge_rq);
    }

    // TODO should be moved out of the scheduler step
    // NOTE(review): lines 482-484 were dropped by the extraction.
    patch_list.build_global_idx_map(); // TODO check if required : added because possible bug
                                       // because of for each patch & serial patch tree
    // update_local_dtcnt_value();
    // update_local_load_value(); disable the load value compute it should be done only in the
    // models

    if (split_rq.size() > 0 || merge_rq.size() > 0) {
        // NOTE(review): the body (original line 492) was dropped by the
        // extraction.
    }

    // std::cout << dump_status();

    timers.global_timer.end();
    timers.print_stats();
}
500
501/*
502void SchedulerMPI::scheduler_step(bool do_split_merge,bool do_load_balancing){
503
504 // update patch list
505 patch_list.sync_global();
506
507
508 if(do_split_merge){
509 // rebuild patch index map
510 patch_list.build_global_idx_map();
511
512 // apply reduction on leafs and corresponding parents
513 patch_tree.partial_values_reduction(
514 patch_list.global,
515 patch_list.id_patch_to_global_idx);
516
517 // Generate merge and split request
518 std::unordered_set<u64> split_rq = patch_tree.get_split_request(crit_patch_split);
519 std::unordered_set<u64> merge_rq = patch_tree.get_merge_request(crit_patch_merge);
520
521
522 // apply split requests
523 // update patch_list.global same on every node
524 // and split patchdata accordingly if owned
525 // & update tree
526 split_patches(split_rq);
527
528 // update packing index
529 // same operation on every cluster nodes
530 set_patch_pack_values(merge_rq);
531
532 // update patch list
533 // necessary to update load values in splitted patches
534 // alternative : disable this step and set fake load values (load parent / 8)
535 //alternative impossible if gravity because we have to compute the multipole
536 owned_patch_id = patch_list.build_local();
537 patch_list.sync_global();
538 }
539
540 if(do_load_balancing){
541 // generate LB change list
542 std::vector<std::tuple<u64, i32, i32,i32>> change_list =
543 make_change_list(patch_list.global);
544
545 // apply LB change list
546 patch_data.apply_change_list(change_list, patch_list);
547 }
548
549 if(do_split_merge){
550 // apply merge requests
551 // & update tree
552 merge_patches(merge_rq);
553
554
555
556 // if(Merge) update patch list
557 if(! merge_rq.empty()){
558 owned_patch_id = patch_list.build_local();
559 patch_list.sync_global();
560 }
561 }
562
563 //rebuild local table
564 owned_patch_id = patch_list.build_local();
565}
566//*/
567
/// Build a human-readable dump of the whole scheduler state: global and local
/// patch lists, id-to-index maps, per-patch data object counts, and the patch
/// tree. Intended for debugging; the exact output format is free-form.
/// @return the formatted multi-line status string
std::string PatchScheduler::dump_status() {

    using namespace shamrock::patch;

    std::stringstream ss;

    ss << "----- MPI Scheduler dump -----\n\n";
    ss << " -> SchedulerPatchList\n";

    ss << " len global : " << patch_list.global.size() << "\n";
    ss << " len local : " << patch_list.local.size() << "\n";

    // One line per patch: id, load, owner, pack index, and the integer
    // coordinate ranges on each axis.
    ss << " global content : \n";
    for (Patch &p : patch_list.global) {

        ss << " -> " << p.id_patch << " : " << p.load_value << " " << p.node_owner_id << " "
           << p.pack_node_index << " "
           << "( [" << p.coord_min[0] << "," << p.coord_max[0] << "] "
           << " [" << p.coord_min[1] << "," << p.coord_max[1] << "] "
           << " [" << p.coord_min[2] << "," << p.coord_max[2] << "] )\n";
    }
    ss << " local content : \n";
    for (Patch &p : patch_list.local) {

        ss << " -> id : " << p.id_patch << " : " << p.load_value << " " << p.node_owner_id
           << " " << p.pack_node_index << " "
           << "( [" << p.coord_min[0] << "," << p.coord_max[0] << "] "
           << " [" << p.coord_min[1] << "," << p.coord_max[1] << "] "
           << " [" << p.coord_min[2] << "," << p.coord_max[2] << "] )\n";
    }

    ss << shambase::format(
        "patch_list.id_patch_to_global_idx :\n{}\n", patch_list.id_patch_to_global_idx);
    ss << shambase::format(
        "patch_list.id_patch_to_local_idx :\n{}\n", patch_list.id_patch_to_local_idx);

    ss << " -> SchedulerPatchData\n";
    ss << " owned data : \n";

    // Object count of every locally owned patch data layer.
    patch_data.for_each_patchdata([&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
        ss << "patch id : " << patch_id << " len = " << pdat.get_obj_cnt() << "\n";
    });

    /*
    for(auto & [k,pdat] : patch_data.owned_data){
        ss << " -> id : " << k << " len : (" <<
        pdat.pos_s.size() << " " <<pdat.pos_d.size() << " " <<
        pdat.U1_s.size() << " " <<pdat.U1_d.size() << " " <<
        pdat.U3_s.size() << " " <<pdat.U3_d.size() << " "
        << ")\n";
    }
    */

    ss << " -> SchedulerPatchTree\n";

    // Per tree node: children ids, linked patch, coordinate range and flags.
    for (auto &[k, pnode] : patch_tree.tree) {
        ss << shambase::format(
            " -> id : {} -> ({}) <=> {} [{}, {}] (cl={} il={} l={} pid={})\n",
            k,
            pnode.tree_node.childs_nid,
            pnode.linked_patchid,
            pnode.patch_coord.coord_min,
            pnode.patch_coord.coord_max,
            pnode.tree_node.child_are_all_leafs,
            pnode.tree_node.is_leaf,
            pnode.tree_node.level,
            pnode.tree_node.parent_nid);
    }

    return ss.str();
}
639
/// Format the physical coordinate range of a patch as "coord = {min} {max}",
/// dispatching on the layout's main field type to pick the conversion type.
/// @param p the patch whose integer coordinates are converted
/// @return the formatted coordinate string
std::string PatchScheduler::format_patch_coord(shamrock::patch::Patch p) {
    std::string ret;
    if (pdl_old().check_main_field_type<f32_3>()) {
        auto [bmin, bmax] = patch_data.sim_box.patch_coord_to_domain<f32_3>(p);
        ret = shambase::format("coord = {} {}", bmin, bmax);
    } else if (pdl_old().check_main_field_type<f64_3>()) {
        auto [bmin, bmax] = patch_data.sim_box.patch_coord_to_domain<f64_3>(p);
        ret = shambase::format("coord = {} {}", bmin, bmax);
    } else if (pdl_old().check_main_field_type<u32_3>()) {
        auto [bmin, bmax] = patch_data.sim_box.patch_coord_to_domain<u32_3>(p);
        ret = shambase::format("coord = {} {}", bmin, bmax);
    } else if (pdl_old().check_main_field_type<u64_3>()) {
        auto [bmin, bmax] = patch_data.sim_box.patch_coord_to_domain<u64_3>(p);
        ret = shambase::format("coord = {} {}", bmin, bmax);
    } else {
        // NOTE(review): the throw expression line (original 655) was dropped
        // by the extraction; only its message remains.
        "the main field does not match any");
    }
    return ret;
}
660
/// Verify the locality invariant for the given position vector type: every
/// object in each locally owned patch must lie inside that patch's converted
/// coordinate range (main field is assumed to be the position, index 0).
/// @tparam vec position vector type matching the layout's main field
/// @param sched scheduler whose owned patch data is checked
template<class vec>
void check_locality_t(PatchScheduler &sched) {

    StackEntry stack_loc{};

    using namespace shamrock::patch;
    // NOTE(review): the for-each lambda opening line (binding `pid`, `p` and
    // `pdat`, all used below) was dropped by the extraction.
    PatchDataField<vec> &main_field = pdat.get_field<vec>(0);
    auto [bmin_p0, bmax_p0] = sched.patch_data.sim_box.patch_coord_to_domain<vec>(p);

    main_field.check_err_range(
        [&](vec val, vec vmin, vec vmax) {
            return Patch::is_in_patch_converted(val, vmin, vmax);
        },
        bmin_p0,
        bmax_p0,
        shambase::format("patch id = {}", pid));
    });
}
680
/// Run the locality check (check_locality_t) with the vector type matching
/// the layout's main field; throws if the main field type is unsupported.
void PatchScheduler::check_patchdata_locality_correctness() {

    StackEntry stack_loc{};

    if (pdl_old().check_main_field_type<f32_3>()) {
        check_locality_t<f32_3>(*this);
    } else if (pdl_old().check_main_field_type<f64_3>()) {
        check_locality_t<f64_3>(*this);
    } else if (pdl_old().check_main_field_type<u32_3>()) {
        check_locality_t<u32_3>(*this);
    } else if (pdl_old().check_main_field_type<u64_3>()) {
        check_locality_t<u64_3>(*this);
    } else if (pdl_old().check_main_field_type<i64_3>()) {
        check_locality_t<i64_3>(*this);
    } else {
        // NOTE(review): the throw expression line (original 696) was dropped
        // by the extraction; only its message remains.
        "the main field does not match any");
    }
}
700
/// Apply the given split requests: split each requested tree node into its 8
/// children, register the 8 sub-patches in the patch list, relink the tree,
/// and split the owned patch data accordingly. On a range-check failure the
/// offending patch layout is logged before rethrowing.
/// @param split_rq tree node ids to split
void PatchScheduler::split_patches(std::unordered_set<u64> split_rq) {
    StackEntry stack_loc{};
    for (u64 tree_id : split_rq) {

        patch_tree.split_node(tree_id);
        PatchTree::Node &splitted_node = patch_tree.tree[tree_id];

        shamrock::patch::Patch old_patch
            = patch_list.global[patch_list.id_patch_to_global_idx[splitted_node.linked_patchid]];

        // Split the patch in the list; the 8 returned indices address the new
        // sub-patches in the global vector.
        auto [idx_p0, idx_p1, idx_p2, idx_p3, idx_p4, idx_p5, idx_p6, idx_p7]
            = patch_list.split_patch(splitted_node.linked_patchid);

        u64 old_patch_id = splitted_node.linked_patchid;

        // The split node no longer maps to a patch; each child maps to one
        // sub-patch.
        splitted_node.linked_patchid = u64_max;
        patch_tree.tree[splitted_node.tree_node.childs_nid[0]].linked_patchid
            = patch_list.global[idx_p0].id_patch;
        patch_tree.tree[splitted_node.tree_node.childs_nid[1]].linked_patchid
            = patch_list.global[idx_p1].id_patch;
        patch_tree.tree[splitted_node.tree_node.childs_nid[2]].linked_patchid
            = patch_list.global[idx_p2].id_patch;
        patch_tree.tree[splitted_node.tree_node.childs_nid[3]].linked_patchid
            = patch_list.global[idx_p3].id_patch;
        patch_tree.tree[splitted_node.tree_node.childs_nid[4]].linked_patchid
            = patch_list.global[idx_p4].id_patch;
        patch_tree.tree[splitted_node.tree_node.childs_nid[5]].linked_patchid
            = patch_list.global[idx_p5].id_patch;
        patch_tree.tree[splitted_node.tree_node.childs_nid[6]].linked_patchid
            = patch_list.global[idx_p6].id_patch;
        patch_tree.tree[splitted_node.tree_node.childs_nid[7]].linked_patchid
            = patch_list.global[idx_p7].id_patch;

        try {
            // NOTE(review): the patch-data split call line (original 735) was
            // dropped by the extraction; these are its arguments.
                old_patch_id,
                {patch_list.global[idx_p0],
                 patch_list.global[idx_p1],
                 patch_list.global[idx_p2],
                 patch_list.global[idx_p3],
                 patch_list.global[idx_p4],
                 patch_list.global[idx_p5],
                 patch_list.global[idx_p6],
                 patch_list.global[idx_p7]});
        } catch (const PatchDataRangeCheckError &e) {
            logger::err_ln("SchedulerPatchData", "catched range issue with patchdata split");

            logger::raw_ln(" old patch", old_patch.id_patch, format_patch_coord(old_patch));

            logger::err_ln("Scheduler", "global patch list :");
            for (shamrock::patch::Patch &p : patch_list.global) {
                logger::raw_ln(" patch", p.id_patch, format_patch_coord(p));
            }

            // NOTE(review): the rethrow expression line (original 755) was
            // dropped by the extraction; only its message remains.
                "\n Initial error : "
                + shambase::increase_indent(std::string("\n") + e.what(), "\n |"));
        }
    }
}
761
/// Apply the given merge requests: for each requested tree node, merge its 8
/// children's patches back into one patch (data merged on the owner rank of
/// the first child) and collapse the tree node.
/// @param merge_rq tree node ids whose children are merged
inline void PatchScheduler::merge_patches(std::unordered_set<u64> merge_rq) {
    StackEntry stack_loc{};
    for (u64 tree_id : merge_rq) {

        PatchTree::Node &to_merge_node = patch_tree.tree[tree_id];

        // std::cout << "merging patch tree id : " << tree_id << "\n";

        u64 patch_id0 = patch_tree.tree[to_merge_node.tree_node.childs_nid[0]].linked_patchid;
        u64 patch_id1 = patch_tree.tree[to_merge_node.tree_node.childs_nid[1]].linked_patchid;
        u64 patch_id2 = patch_tree.tree[to_merge_node.tree_node.childs_nid[2]].linked_patchid;
        u64 patch_id3 = patch_tree.tree[to_merge_node.tree_node.childs_nid[3]].linked_patchid;
        u64 patch_id4 = patch_tree.tree[to_merge_node.tree_node.childs_nid[4]].linked_patchid;
        u64 patch_id5 = patch_tree.tree[to_merge_node.tree_node.childs_nid[5]].linked_patchid;
        u64 patch_id6 = patch_tree.tree[to_merge_node.tree_node.childs_nid[6]].linked_patchid;
        u64 patch_id7 = patch_tree.tree[to_merge_node.tree_node.childs_nid[7]].linked_patchid;

        // print list of patch that will merge
        // std::cout << format(" -> (%d %d %d %d %d %d %d %d)\n", patch_id0, patch_id1, patch_id2,
        // patch_id3, patch_id4, patch_id5, patch_id6, patch_id7);

        if (patch_list.global[patch_list.id_patch_to_global_idx[patch_id0]].node_owner_id
            // NOTE(review): the condition's right-hand side and the merge call
            // line (original 784-785) were dropped by the extraction; these
            // are the call's arguments.
            patch_id0,
            {patch_id0,
             patch_id1,
             patch_id2,
             patch_id3,
             patch_id4,
             patch_id5,
             patch_id6,
             patch_id7});
        }

        // NOTE(review): lines 797-805 (presumably the patch-list merge of the
        // 8 patches) were dropped by the extraction.

        patch_tree.merge_node_dm1(tree_id);

        // the collapsed node now maps to the surviving (first-child) patch id
        to_merge_node.linked_patchid = patch_id0;
    }
}
812
/// For each merge request, point the pack index of children 1..7 at the
/// global index of child 0, so the later data exchange gathers the 8 siblings
/// onto the same node before the actual merge.
/// @param merge_rq tree node ids scheduled for merging
inline void PatchScheduler::set_patch_pack_values(std::unordered_set<u64> merge_rq) {

    for (u64 tree_id : merge_rq) {

        PatchTree::Node &to_merge_node = patch_tree.tree[tree_id];

        // global index of the first child's patch: the pack target
        u64 idx_pack
            = patch_list.id_patch_to_global_idx[patch_tree.tree[to_merge_node.get_child_nid(0)]
                                                    .linked_patchid];

        // std::cout << "node id : " << patch_list.global[idx_pack].id_patch << " should merge with
        // : ";

        for (u8 i = 1; i < 8; i++) {
            // std::cout << patch_tree.tree[to_merge_node.get_child_nid(i)].linked_patchid << " ";
            // NOTE(review): lines 828-829 (the patch_list.global[...] lookup
            // prefix of this expression) were dropped by the extraction.
                [patch_tree.tree[to_merge_node.get_child_nid(i)].linked_patchid]]
                .pack_node_index = idx_pack;
        } // std::cout << std::endl;
    }
}
835
836void PatchScheduler::dump_local_patches(std::string filename) {
837
838 using namespace shamrock::patch;
839
840 std::ofstream fout(filename);
841
842 if (pdl_old().check_main_field_type<f32_3>()) {
843
844 std::tuple<f32_3, f32_3> box_transform = get_box_tranform<f32_3>();
845
846 for (const Patch &p : patch_list.local) {
847
848 f32_3 box_min
849 = f32_3{p.coord_min[0], p.coord_min[1], p.coord_min[2]} * std::get<1>(box_transform)
850 + std::get<0>(box_transform);
851 f32_3 box_max = (f32_3{p.coord_max[0], p.coord_max[1], p.coord_max[2]} + 1)
852 * std::get<1>(box_transform)
853 + std::get<0>(box_transform);
854
855 fout << p.id_patch << "|" << p.load_value << "|" << p.node_owner_id << "|"
856 << p.pack_node_index << "|" << box_min.x() << "|" << box_max.x() << "|"
857 << box_min.y() << "|" << box_max.y() << "|" << box_min.z() << "|" << box_max.z()
858 << "|" << "\n";
859 }
860
861 fout.close();
862
863 } else if (pdl_old().check_main_field_type<f64_3>()) {
864
865 std::tuple<f64_3, f64_3> box_transform = get_box_tranform<f64_3>();
866
867 for (const Patch &p : patch_list.local) {
868
869 f64_3 box_min
870 = f64_3{p.coord_min[0], p.coord_min[1], p.coord_min[2]} * std::get<1>(box_transform)
871 + std::get<0>(box_transform);
872 f64_3 box_max = (f64_3{p.coord_max[0], p.coord_max[1], p.coord_max[3]} + 1)
873 * std::get<1>(box_transform)
874 + std::get<0>(box_transform);
875
876 fout << p.id_patch << "|" << p.load_value << "|" << p.node_owner_id << "|"
877 << p.pack_node_index << "|" << box_min.x() << "|" << box_max.x() << "|"
878 << box_min.y() << "|" << box_max.y() << "|" << box_min.z() << "|" << box_max.z()
879 << "|" << "\n";
880 }
881
882 fout.close();
883
884 } else {
886 "the chosen type for the main field is not handled");
887 }
888}
889
/// A point-to-point MPI message: the owning payload buffer plus the peer rank
/// and tag used by send_messages / recv_probe_messages below.
struct Message {
    // payload; on the receive side it is allocated by recv_probe_messages
    std::unique_ptr<shamcomm::CommunicationBuffer> buf;
    // peer rank (destination for sends, source for receives)
    i32 rank;
    // MPI tag identifying the message
    i32 tag;
};
895
/// Post a non-blocking send for every message in msgs, appending each MPI
/// request to rqs. Payloads are transferred as packed 64-bit words, hence the
/// 8-byte alignment requirement on the buffer size.
/// @param msgs messages to send (buffers must be populated)
/// @param rqs request list the new requests are appended to
void send_messages(std::vector<Message> &msgs, std::vector<MPI_Request> &rqs) {
    for (auto &msg : msgs) {
        rqs.push_back(MPI_Request{});
        u32 rq_index = rqs.size() - 1;
        auto &rq = rqs[rq_index];

        u64 bsize = msg.buf->get_size();
        if (bsize % 8 != 0) {
            // NOTE(review): the throw expression line (original 904) was
            // dropped by the extraction; only its message remains.
            "the following mpi comm assume that we can send longs to pack 8byte");
        }
        u64 lcount = bsize / 8;
        if (lcount > i32_max) {
            shambase::throw_with_loc<std::runtime_error>("The message is too large for MPI");
        }

        // NOTE(review): the non-blocking send call line (original 912,
        // presumably shamcomm::mpi::Isend) was dropped by the extraction;
        // these are its arguments.
            msg.buf->get_ptr(),
            lcount,
            get_mpi_type<u64>(),
            msg.rank,
            msg.tag,
            MPI_COMM_WORLD,
            &rq);
    }
}
922
/// For every expected message, probe its size, allocate a matching
/// communication buffer, and post a non-blocking receive, appending each MPI
/// request to rqs. Counts are in 64-bit words (hence the * 8 byte size).
/// @param msgs messages to receive; rank/tag must be set, buf is allocated here
/// @param rqs request list the new requests are appended to
void recv_probe_messages(std::vector<Message> &msgs, std::vector<MPI_Request> &rqs) {

    for (auto &msg : msgs) {
        rqs.push_back(MPI_Request{});
        u32 rq_index = rqs.size() - 1;
        auto &rq = rqs[rq_index];

        MPI_Status st;
        i32 cnt;
        // blocking probe to learn the incoming message size before allocating
        shamcomm::mpi::Probe(msg.rank, msg.tag, MPI_COMM_WORLD, &st);
        shamcomm::mpi::Get_count(&st, get_mpi_type<u64>(), &cnt);

        msg.buf = std::make_unique<shamcomm::CommunicationBuffer>(
            cnt * 8, shamsys::instance::get_compute_scheduler_ptr());

        // NOTE(review): the non-blocking receive call line (original 938,
        // presumably shamcomm::mpi::Irecv) was dropped by the extraction;
        // these are its arguments.
            msg.buf->get_ptr(), cnt, get_mpi_type<u64>(), msg.rank, msg.tag, MPI_COMM_WORLD, &rq);
    }
}
942
/// Gather every patch's data onto one rank: each owner serializes its patch
/// data layers and sends them to rank 0, which probes/receives and
/// deserializes them all. Returns the gathered layers (in patch-list order)
/// on the gathering rank.
/// NOTE(review): the `rank` parameter is not visibly used below — the
/// receive side is hard-coded to rank 0; confirm against the original source.
/// @param rank intended gathering rank
std::vector<std::unique_ptr<shamrock::patch::PatchDataLayer>> PatchScheduler::gather_data(
    u32 rank) {

    using namespace shamrock::patch;

    auto plist = this->patch_list.global;
    // NOTE(review): this copies owned_data by value — likely expensive;
    // confirm whether a reference was intended.
    auto pdata = this->patch_data.owned_data;

    // serialize one patch data layer into a device byte buffer
    auto serializer = [](shamrock::patch::PatchDataLayer &pdat) {
        shamalgs::SerializeHelper ser(shamsys::instance::get_compute_scheduler_ptr());
        ser.allocate(pdat.serialize_buf_byte_size());
        pdat.serialize_buf(ser);
        return ser.finalize();
    };

    // rebuild a patch data layer from a received byte buffer
    auto deserializer = [&](sham::DeviceBuffer<u8> &&buf) {
        // exchange the buffer held by the distrib data and give it to the serializer
        // NOTE(review): the SerializeHelper construction line (original 960)
        // was dropped by the extraction; these are its arguments.
            shamsys::instance::get_compute_scheduler_ptr(),
            std::forward<sham::DeviceBuffer<u8>>(buf));
        return shamrock::patch::PatchDataLayer::deserialize_buf(ser, get_layout_ptr_old());
    };

    std::vector<Message> send_payloads;

    // every rank serializes and sends its owned patches to rank 0,
    // tagged by the patch's index in the global list
    for (u32 i = 0; i < plist.size(); i++) {
        auto &cpatch = plist[i];
        if (cpatch.node_owner_id == shamcomm::world_rank()) {
            auto &patchdata = pdata.get(cpatch.id_patch);

            sham::DeviceBuffer<u8> tmp = serializer(patchdata);

            send_payloads.push_back(
                Message{
                    std::make_unique<shamcomm::CommunicationBuffer>(
                        std::move(tmp), shamsys::instance::get_compute_scheduler_ptr()),
                    0,
                    i32(i)});
        }
    }

    std::vector<MPI_Request> rqs;
    send_messages(send_payloads, rqs);

    std::vector<Message> recv_payloads;

    // rank 0 expects one message per patch, from that patch's owner
    if (shamcomm::world_rank() == 0) {
        for (u32 i = 0; i < plist.size(); i++) {
            recv_payloads.push_back(
                Message{
                    std::unique_ptr<shamcomm::CommunicationBuffer>{},
                    i32(plist[i].node_owner_id),
                    i32(i)});
        }
    }

    // receive
    recv_probe_messages(recv_payloads, rqs);

    std::vector<MPI_Status> st_lst(rqs.size());
    shamcomm::mpi::Waitall(rqs.size(), rqs.data(), st_lst.data());

    std::vector<std::unique_ptr<PatchDataLayer>> ret;
    for (auto &recv_msg : recv_payloads) {
        // NOTE(review): lines 1007 and 1009 (binding `comm_buf` from the
        // message and declaring the USM buffer target) were dropped by the
        // extraction.
            = shamcomm::CommunicationBuffer::convert_usm(std::move(comm_buf));

        ret.push_back(std::make_unique<PatchDataLayer>(deserializer(std::move(buf))));
    }

    return ret;
}
1017
1018nlohmann::json PatchScheduler::serialize_patch_metadata() {
1019
1020 nlohmann::json jsim_box;
1021 patch_data.sim_box.to_json(jsim_box);
1022
1023 return {
1024 {"patchtree", patch_tree},
1025 {"patchlist", patch_list},
1026 {"patchdata_layout", pdl_old()},
1027 {"sim_box", jsim_box},
1028 {"crit_patch_split", crit_patch_split},
1029 {"crit_patch_merge", crit_patch_merge}};
1030}
Function to run load balancing with the Hilbert curve.
Header file describing a Node Instance.
MPI scheduler.
double f64
Alias for double.
std::uint8_t u8
8 bit unsigned integer
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::uint16_t u16
16 bit unsigned integer
std::int32_t i32
32 bit integer
The MPI scheduler.
void for_each_patch_data(Function &&fct)
for-each macro for patchdata; see example usage
SchedulerPatchData patch_data
handle the data of the patches of the scheduler
u64 crit_patch_split
splitting limit (if load value > crit_patch_split => patch split)
PatchTree patch_tree
handle the tree structure of the patches
void scheduler_step(bool do_split_merge, bool do_load_balancing)
scheduler step
SchedulerPatchList patch_list
handle the list of the patches of the scheduler
std::unordered_set< u64 > owned_patch_id
(owned_patch_id = patch_list.build_local())
std::vector< u64 > add_root_patches(std::vector< shamrock::patch::PatchCoord< 3 > > coords)
add a root patch to the scheduler
u64 crit_patch_merge
merging limit (if load value < crit_patch_merge => patch merge)
void allpush_data(shamrock::patch::PatchDataLayer &pdat)
push data into the scheduler. The content of pdat has to be the same on each node
void add_root_patch()
add patch to the scheduler
std::vector< shamrock::patch::Patch > local
contain the list of patch owned by the current node
void reset_local_pack_index()
reset Patch's pack index value
std::unordered_map< u64, u64 > id_patch_to_local_idx
id_patch_to_local_idx[patch_id] = index in local patch list
std::vector< shamrock::patch::Patch > global
contain the list of all patches in the simulation
void build_global()
rebuild global from the local list of each tables
void invalidate_load_values()
Invalidate current load values (to use after a change to the patches is made)
u64 _next_patch_id
The next available patch id.
std::tuple< u64, u64, u64, u64, u64, u64, u64, u64 > split_patch(u64 id_patch)
split the Patch having id_patch as id and return the index of the 8 subpatches in the global vector
std::unordered_set< u64 > build_local()
select patches owned by the node to rebuild local
std::unordered_map< u64, u64 > id_patch_to_global_idx
id_patch_to_global_idx[patch_id] = index in global patch list
void merge_patch(u64 idx0, u64 idx1, u64 idx2, u64 idx3, u64 idx4, u64 idx5, u64 idx6, u64 idx7)
merge the 8 given patches index in the global vector
void build_local_idx_map()
recompute id_patch_to_local_idx
void check_load_values_valid(SourceLocation loc=SourceLocation{})
Check if the load values are valid, throw otherwise.
void build_global_idx_map()
recompute id_patch_to_global_idx
A buffer allocated in USM (Unified Shared Memory)
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:96
std::string get_time_str() const
Converts the stored nanosecond time to a string representation.
Definition time.hpp:117
void start()
Starts the timer.
Definition time.hpp:106
f64 nanosec
Time in nanosecond.
Definition time.hpp:100
Shamrock communication buffers.
static sham::DeviceBuffer< u8 > convert_usm(CommunicationBuffer &&buf)
destroy the buffer and recover the held object
const var_t & get_main_field_any() const
Get the main field description as a variant object.
PatchDataLayer container class, the layout is described in patchdata_layout.
void insert_elements_in_range(PatchDataLayer &pdat, T bmin, T bmax)
insert elements of pdat only if they are within the range
std::tuple< T, T > get_bounding_box() const
Get the stored bounding box of the domain.
Definition SimBox.hpp:247
void to_json(nlohmann::json &j)
Serializes a SimulationBoxInfo object to a JSON object.
Definition SimBox.cpp:31
std::tuple< T, T > patch_coord_to_domain(const Patch &p) const
get the patch coordinates on the domain
Definition SimBox.hpp:300
static LoadBalancingChangeList make_change_list(std::vector< shamrock::patch::Patch > &global_patch_list)
generate the change list from the list of patch to run the load balancing
static constexpr u64 max_box_sz
maximal value along an axis for the patch coordinate
std::array< u64, 8 > childs_nid
Array of child node ids.
Node information in the patchtree + held patch info.
std::unordered_set< u64 > get_merge_request(u64 crit_load_merge)
Get list of nodes id to merge.
std::unordered_set< u64 > get_split_request(u64 crit_load_split)
Get list of nodes id to split.
void partial_values_reduction(std::vector< Patch > &plist, const std::unordered_map< u64, u64 > &id_patch_to_global_idx)
update values in leafs and parent_of_only_leaf_key only
void merge_node_dm1(u64 idparent)
merge children of idparent (idparent should have only leaves as children)
Definition PatchTree.cpp:66
std::unordered_map< u64, Node > tree
store the tree using a map
Definition PatchTree.hpp:43
void split_node(u64 id)
split a leaf node
Definition PatchTree.cpp:36
void merge_patchdata(u64 new_key, const std::array< u64, 8 > old_keys)
merge 8 old patchdata into one
shamrock::patch::SimulationBoxInfo sim_box
simulation box geometry info
void apply_change_list(const shamrock::scheduler::LoadBalancingChangeList &change_list, SchedulerPatchList &patch_list)
apply a load balancing change list to shuffle patchdata around the cluster
void split_patchdata(u64 key_orginal, const std::array< shamrock::patch::Patch, 8 > patches)
split a patchdata into 8 children according to the 8 patches in arguments
shambase::DistributedData< PatchData > owned_data
map container for patchdata owned by the current node (layout : id_patch,data)
This header file contains utility functions related to exception handling in the code.
constexpr T roundup_pow2_clz(T v) noexcept
round up to the next power of two 0 is rounded up to 1 as it is not a pow of 2 every input above the ...
Definition math.hpp:805
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
std::string increase_indent(std::string in, std::string delim="\n ")
Increase indentation of a string.
Definition string.hpp:197
auto extract_pointer(std::unique_ptr< T > &o, SourceLocation loc=SourceLocation()) -> T
extract content out of unique_ptr
Definition memory.hpp:227
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
constexpr u64 u64_max
u64 max value
constexpr i32 i32_max
i32 max value
header for PatchData related function and declaration
This file contains the definition for the stacktrace related functionality.
Patch object that contain generic patch information.
Definition Patch.hpp:33
u64 pack_node_index
this value means "to pack with index xxx in the global patch table" and not "to pack with id_patch == x...
Definition Patch.hpp:87
u32 node_owner_id
node rank owner of this patch
Definition Patch.hpp:93
u64 load_value
if synchronized contain the load value of the patch
Definition Patch.hpp:88
u64 id_patch
unique key that identify the patch
Definition Patch.hpp:86
header file to manage sycl
void Get_count(const MPI_Status *status, MPI_Datatype datatype, int *count)
MPI wrapper for MPI_Get_count.
Definition wrapper.cpp:222
void Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Irecv.
Definition wrapper.cpp:102
void Probe(int source, int tag, MPI_Comm comm, MPI_Status *status)
MPI wrapper for MPI_Probe.
Definition wrapper.cpp:201
void Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
MPI wrapper for MPI_Allreduce.
Definition wrapper.cpp:119
void Waitall(int count, MPI_Request array_of_requests[], MPI_Status *array_of_statuses)
MPI wrapper for MPI_Waitall.
Definition wrapper.cpp:187
void Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Isend.
Definition wrapper.cpp:85