Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
PatchScheduler.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
25#include "shambase/time.hpp"
30#include <nlohmann/json.hpp>
31#include <unordered_set>
32#include <fstream>
33#include <functional>
34#include <memory>
35#include <stdexcept>
36#include <tuple>
37#include <vector>
38// #include "shamrock/scheduler/SerialPatchTree.hpp"
41// #include "shamrock/legacy/patch/patchdata_buffer.hpp"
43#include "shambackends/math.hpp"
51
    // Tail of struct PatchSchedulerConfig (the struct opener is elided in this rendering).
53    u64 split_load_value = 0_u64; ///< serialized as "split_load_value" (see to_json/from_json below)
54    u64 merge_load_value = 0_u64; ///< serialized as "merge_load_value" (see to_json/from_json below)
55};
56
// Serialize a PatchSchedulerConfig into a JSON object holding the two load
// thresholds under the keys "split_load_value" and "merge_load_value".
63inline void to_json(nlohmann::json &j, const PatchSchedulerConfig &p) {
64    j = nlohmann::json::object();
65    j["split_load_value"] = p.split_load_value;
66    j["merge_load_value"] = p.merge_load_value;
68}
69
// Deserialize a PatchSchedulerConfig from JSON; json::at throws if either key
// is missing, exactly like the original get_to-based form.
76inline void from_json(const nlohmann::json &j, PatchSchedulerConfig &p) {
77    p.split_load_value = j.at("split_load_value").get<u64>();
78    p.merge_load_value = j.at("merge_load_value").get<u64>();
79}
80
86
88
89 public:
90 static constexpr u64 max_axis_patch_coord = LoadBalancer::max_box_sz;
91 static constexpr u64 max_axis_patch_coord_length = LoadBalancer::max_box_sz + 1;
92
95
96 std::shared_ptr<shamrock::patch::PatchDataLayerLayout> pdl_ptr;
97
100
104
105 // using unordered set is not an issue since we use the find command after
106 std::unordered_set<u64> owned_patch_id;
108
    /// Dereference pdl_ptr; shambase::get_check_ref throws instead of returning
    /// a dangling reference when the pointer is null.
109    inline shamrock::patch::PatchDataLayerLayout &pdl_old() {
110        return shambase::get_check_ref(pdl_ptr);
111    }
112
    /// Return the (possibly null) shared pointer to the patch data layer layout.
113    inline std::shared_ptr<shamrock::patch::PatchDataLayerLayout> get_layout_ptr_old() const {
114        return pdl_ptr;
115    }
116
123 void scheduler_step(bool do_split_merge, bool do_load_balancing);
124
125 void init_mpi_required_types();
126
127 void free_mpi_required_types();
128
130 const std::shared_ptr<shamrock::patch::PatchDataLayerLayout> &pdl_ptr,
131 u64 crit_split,
132 u64 crit_merge);
133
135
136 std::string dump_status();
137
    /// Recompute the load_value of every patch owned by this rank with the
    /// user-supplied load function.
    /// NOTE(review): source lines 140 and 143 are elided in this rendering; `p` is
    /// presumably a Patch reference resolved from `id` — confirm against the header.
138    inline void update_local_load_value(std::function<u64(shamrock::patch::Patch)> load_function) {
139        for (u64 id : owned_patch_id) {
141            p.load_value = load_function(p);
142        }
144    }
145
146 template<class vectype>
147 std::tuple<vectype, vectype> get_box_tranform();
148
149 template<class vectype>
150 std::tuple<vectype, vectype> get_box_volume();
151
152 bool should_resize_box(bool node_in);
153
161 template<class vectype>
162 void set_coord_domain_bound(vectype bmin, vectype bmax) {
163
164 if (!pdl_old().check_main_field_type<vectype>()) {
165 std::invalid_argument(
166 std::string("the main field is not of the correct type to call this function\n")
167 + "fct called : " + __PRETTY_FUNCTION__
168 + "current patch data layout : " + pdl_old().get_description_str());
169 }
170
171 patch_data.sim_box.set_bounding_box<vectype>({bmin, bmax});
172
173 shamlog_debug_ln("PatchScheduler", "box resized to :", bmin, bmax);
174 }
175
183
184 template<u32 dim>
185 void make_patch_base_grid(std::array<u32, dim> patch_count);
186
    /// Tuple-convenience overload of set_coord_domain_bound.
    /// NOTE(review): source line 196 is elided in this rendering; it presumably
    /// forwards to set_coord_domain_bound(a, b) — confirm against the header.
193    template<class vectype>
194    void set_coord_domain_bound(std::tuple<vectype, vectype> box) {
195        auto [a, b] = box;
197    }
198
199 std::string format_patch_coord(shamrock::patch::Patch p);
200
201 void check_patchdata_locality_correctness();
202
203 [[deprecated]]
204 void dump_local_patches(std::string filename);
205
206 std::vector<std::unique_ptr<shamrock::patch::PatchDataLayer>> gather_data(u32 rank);
207
216 //[[deprecated]]
217 // inline u64 add_patch(shamrock::patch::Patch p, shamrock::patch::PatchData && pdat){
218 // p.id_patch = patch_list._next_patch_id;
219 // patch_list._next_patch_id ++;
220 //
221 // patch_list.global.push_back(p);
222 //
223 // patch_data.owned_data.insert({p.id_patch , pdat});
224 //
225 // return p.id_patch;
226 //}
227
228 void add_root_patch();
229
230 [[deprecated]]
231 void sync_build_LB(bool global_patch_sync, bool balance_load);
232
233 template<class vec>
234 inline shamrock::patch::PatchCoordTransform<vec> get_patch_transform() {
235 return get_sim_box().template get_patch_transform<vec>();
236 }
237
238 // template<class vec>
239 // inline SerialPatchTree<vec> make_serial_ptree(){
240 // return SerialPatchTree<vec>(patch_tree, get_patch_transform<vec>());
241 // }
242
    /// Apply fct(patch_id, patch, patchdata) to every owned patch-data, skipping
    /// patches flagged in error mode.
    /// NOTE(review): source lines 261-262 are elided in this rendering; cur_p is
    /// presumably the Patch looked up from patch_id — confirm against the header.
257    template<class Function>
258    inline void for_each_patch_data(Function &&fct) {
259
260        patch_data.for_each_patchdata([&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
263
264            if (!cur_p.is_err_mode()) {
265                fct(patch_id, cur_p, pdat);
266            }
267        });
268    }
269
    /// Apply fct(patch_id, patch) to every owned patch, skipping patches in error mode.
    /// NOTE(review): source lines 274-275 are elided in this rendering; cur_p is
    /// presumably the Patch looked up from patch_id — confirm against the header.
270    template<class Function>
271    inline void for_each_patch(Function &&fct) {
272
273        patch_data.for_each_patchdata([&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
276
277            // TODO should feed the sycl queue to the lambda
278            if (!cur_p.is_err_mode()) {
279                fct(patch_id, cur_p);
280            }
281        });
282    }
283
284 inline void for_each_global_patch(
285 const std::function<void(const shamrock::patch::Patch &)> &fct) {
286 for (const shamrock::patch::Patch &p : patch_list.global) {
287 if (!p.is_err_mode()) {
288 fct(p);
289 }
290 }
291 }
292
293 inline void for_each_local_patch(
294 const std::function<void(const shamrock::patch::Patch &)> &fct) {
295 for (const shamrock::patch::Patch &p : patch_list.local) {
296 if (!p.is_err_mode()) {
297 fct(p);
298 }
299 }
300 }
301
302 inline void for_each_local_patchdata(
303 const std::function<void(const shamrock::patch::Patch &, shamrock::patch::PatchDataLayer &)>
304 &fct) {
305 for (const shamrock::patch::Patch &p : patch_list.local) {
306 if (!p.is_err_mode()) {
307 fct(p, patch_data.get_pdat(p.id_patch));
308 }
309 }
310 }
311
    /// Run fct on every owned patch whose patch data is non-empty and not in error mode.
    /// NOTE(review): source lines 315-316 are elided in this rendering; cur_p is
    /// presumably the Patch resolved from patch_id — confirm against the header.
312    inline void for_each_local_patch_nonempty(
313        std::function<void(const shamrock::patch::Patch &)> fct) {
314        patch_data.for_each_patchdata([&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
317
318            if ((!cur_p.is_err_mode()) && (!pdat.is_empty())) {
319                fct(cur_p);
320            }
321        });
322    }
323
    /// Return the MPI rank that owns the given patch id (Patch::node_owner_id).
    /// NOTE(review): source lines 325-326 are elided in this rendering; cur_p is
    /// presumably the Patch looked up from patch_id — confirm against the header.
324    inline u32 get_patch_rank_owner(u64 patch_id) {
327        return cur_p.node_owner_id;
328    }
329
    /// Run fct(patch, patchdata) on every owned patch-data that is non-empty and
    /// whose patch is not in error mode.
    /// NOTE(review): source lines 333-334 are elided in this rendering; cur_p is
    /// presumably the Patch resolved from patch_id — confirm against the header.
330    inline void for_each_patchdata_nonempty(
331        std::function<void(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &)> fct) {
332        patch_data.for_each_patchdata([&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
335
336            if ((!cur_p.is_err_mode()) && (!pdat.is_empty())) {
337                fct(cur_p, pdat);
338            }
339        });
340    }
341
    /// Build a DistributedData<T> by applying fct to every owned (patch, patchdata) pair,
    /// keyed by patch id.
    /// NOTE(review): source line 345 is elided in this rendering; `ret` is presumably
    /// declared there as shambase::DistributedData<T> — confirm against the header.
342    template<class T>
343    inline shambase::DistributedData<T> map_owned_patchdata(
344        std::function<T(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
346
347        using namespace shamrock::patch;
348        for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
349            ret.add_obj(id_patch, fct(cur_p, pdat));
350        });
351
352        return ret;
353    }
354
    /// Gather locally computed per-patch values to all ranks (keyed by patch id)
    /// via shamalgs::collective::fetch_all_simple.
    /// NOTE(review): source line 357 (the `src` parameter declaration) is elided in
    /// this rendering — confirm its exact type against the header.
355    template<class T>
356    inline shambase::DistributedData<T> distrib_data_local_to_all_simple(
358        using namespace shamrock::patch;
359
360        // TODO : after a split the scheduler patch list state does not match global =
361        // allgather(local) but here it is implicitely assumed, that's ... bad
362        return shamalgs::collective::fetch_all_simple<T, Patch>(
363            src, patch_list.local, patch_list.global, [](Patch p) {
364                return p.id_patch;
365            });
366    }
367
    /// Variant of distrib_data_local_to_all_simple using the store/load collective
    /// (shamalgs::collective::fetch_all_storeload).
    /// NOTE(review): source line 370 (the `src` parameter declaration) is elided in
    /// this rendering — confirm its exact type against the header.
368    template<class T>
369    inline shambase::DistributedData<T> distrib_data_local_to_all_load_store(
371        using namespace shamrock::patch;
372
373        return shamalgs::collective::fetch_all_storeload<T, Patch>(
374            src, patch_list.local, patch_list.global, [](Patch p) {
375                return p.id_patch;
376            });
377    }
378
    /// Map owned patch data to T via fct, then distribute the result to all ranks
    /// with the simple fetch collective.
    /// NOTE(review): source line 382 is elided in this rendering; `ret` is presumably
    /// declared there as shambase::DistributedData<T> — confirm against the header.
379    template<class T>
380    inline shambase::DistributedData<T> map_owned_patchdata_fetch_simple(
381        std::function<T(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
383
384        using namespace shamrock::patch;
385        for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
386            ret.add_obj(id_patch, fct(cur_p, pdat));
387        });
388
389        return distrib_data_local_to_all_simple(ret);
390    }
391
    /// Map owned patch data to T via fct, then distribute the result to all ranks
    /// with the store/load fetch collective.
    /// NOTE(review): source line 395 is elided in this rendering; `ret` is presumably
    /// declared there as shambase::DistributedData<T> — confirm against the header.
392    template<class T>
393    inline shambase::DistributedData<T> map_owned_patchdata_fetch_load_store(
394        std::function<T(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
396
397        using namespace shamrock::patch;
398        for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
399            ret.add_obj(id_patch, fct(cur_p, pdat));
400        });
401
402        return distrib_data_local_to_all_load_store(ret);
403    }
404
    /// Wrap map_owned_patchdata_fetch_simple(fct) in a PatchField<T> (all-ranks view).
405    template<class T>
406    inline shamrock::patch::PatchField<T> map_owned_to_patch_field_simple(
407        std::function<T(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
408        return shamrock::patch::PatchField<T>(map_owned_patchdata_fetch_simple(fct));
409    }
410
    /// Wrap map_owned_patchdata_fetch_load_store(fct) in a PatchField<T> (all-ranks view).
411    template<class T>
412    inline shamrock::patch::PatchField<T> map_owned_to_patch_field_load_store(
413        std::function<T(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
414        return shamrock::patch::PatchField<T>(map_owned_patchdata_fetch_load_store(fct));
415    }
416
417 inline u64 get_rank_count() {
418 StackEntry stack_loc{};
419 using namespace shamrock::patch;
420 u64 num_obj = 0; // TODO get_rank_count() in scheduler
421 for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
422 num_obj += pdat.get_obj_cnt();
423 });
424
425 return num_obj;
426 }
427
428 inline u64 get_total_obj_count() {
429 StackEntry stack_loc{};
430 u64 part_cnt = get_rank_count();
431 return shamalgs::collective::allreduce_sum(part_cnt);
432 }
433
    /// Concatenate the contents of field `field_idx` from every owned patch-data into
    /// one sycl buffer of num_obj * nvar elements. Returns a null unique_ptr when this
    /// rank owns no objects.
434    template<class T>
435    inline std::unique_ptr<sycl::buffer<T>> rankgather_field(u32 field_idx) {
436        StackEntry stack_loc{};
437        std::unique_ptr<sycl::buffer<T>> ret;
438
439        auto fd = pdl_old().get_field<T>(field_idx);
440        u64 nvar = fd.nvar;
441
442        u64 num_obj = get_rank_count();
443
444        if (num_obj > 0) {
445            ret = std::make_unique<sycl::buffer<T>>(num_obj * nvar);
446
447            using namespace shamrock::patch;
448
449            u64 ptr = 0; // TODO accumulate_field() in scheduler ?
450            for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
451                using namespace shamalgs::memory;
452                using namespace shambase;
453
                // Copy this patch's field data into `ret` at the running offset `ptr`,
                // then advance the offset by obj_cnt * nvar.
454                if (pdat.get_obj_cnt() > 0) {
455                    write_with_offset_into(
456                        shamsys::instance::get_compute_scheduler().get_queue(),
457                        get_check_ref(ret),
458                        pdat.get_field<T>(field_idx).get_buf(),
459                        ptr,
460                        pdat.get_obj_cnt() * nvar);
461
462                    ptr += pdat.get_obj_cnt() * nvar;
463                }
464            });
465        }
466
467        return ret;
468    }
469
470 // template<class Function, class Pfield>
471 // inline void compute_patch_field(Pfield & field, MPI_Datatype & dtype , Function && lambda){
472 // field.local_nodes_value.resize(patch_list.local.size());
473 //
474 //
475 //
476 // for (u64 idx = 0; idx < patch_list.local.size(); idx++) {
477 //
478 // Patch &cur_p = patch_list.local[idx];
479 //
480 // PatchDataBuffer pdatbuf =
481 // attach_to_patchData(patch_data.owned_data.at(cur_p.id_patch));
482 //
483 // field.local_nodes_value[idx] =
484 // lambda(shamsys::instance::get_compute_queue(),cur_p,pdatbuf);
485 //
486 // }
487 //
488 // field.build_global(dtype);
489 //
490 // }
491
    /// Compute a per-patch field value with `lambda` for every local patch, then
    /// build the global field via an MPI collective (field.build_global(dtype)).
    /// NOTE(review): source line 498 is elided in this rendering; cur_p is presumably
    /// patch_list.local[idx] (see the commented-out variant above) — confirm.
492    template<class Function, class Pfield>
493    inline void compute_patch_field(Pfield &field, MPI_Datatype &dtype, Function &&lambda) {
494        field.local_nodes_value.resize(patch_list.local.size());
495
496        for (u64 idx = 0; idx < patch_list.local.size(); idx++) {
497
499
500            if (!cur_p.is_err_mode()) {
501                field.local_nodes_value[idx] = lambda(
502                    shamsys::instance::get_compute_queue(),
503                    cur_p,
504                    patch_data.owned_data.get(cur_p.id_patch));
505            }
506        }
507
508        field.build_global(dtype);
509    }
510
    /// Build a node-set edge holding references to every non-empty owned patch data
    /// and return it wrapped in a shared_ptr.
    /// NOTE(review): source lines 512-513 are elided in this rendering; `node_set_edge`
    /// and `edge` are presumably declared there, which is why the braces appear
    /// unbalanced here — confirm against the original header.
511    inline auto get_node_set_edge_patchdata_layer_refs() {
514        edge.free_alloc();
515        using namespace shamrock::patch;
516        for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
517            edge.patchdatas.add_obj(cur_p.id_patch, std::ref(pdat));
518        });
519        });
520
521        return std::make_shared<decltype(node_set_edge)>(std::move(node_set_edge));
522    };
523
530 std::vector<u64> add_root_patches(std::vector<shamrock::patch::PatchCoord<3>> coords);
531
532 shamrock::patch::SimulationBoxInfo &get_sim_box() { return patch_data.sim_box; }
533
534 nlohmann::json serialize_patch_metadata();
535
536 private:
537 void split_patches(std::unordered_set<u64> split_rq);
538 void merge_patches(std::unordered_set<u64> merge_rq);
539
540 void set_patch_pack_values(std::unordered_set<u64> merge_rq);
541};
Function to run load balancing with the Hilbert curve.
Node that applies a custom function to modify connected edges.
Defines the PatchDataLayerRefs class for managing distributed references to patch data layers.
void to_json(nlohmann::json &j, const PatchSchedulerConfig &p)
Converts a PatchSchedulerConfig object to a JSON object.
void from_json(const nlohmann::json &j, PatchSchedulerConfig &p)
Deserializes a PatchSchedulerConfig object from a JSON object.
Header file for the patch struct and related function.
PatchData handling.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
The MPI scheduler.
void set_coord_domain_bound(std::tuple< vectype, vectype > box)
modify the bounding box of the patch domain
void for_each_patch_data(Function &&fct)
for-each helper for patchdata; example usage
SchedulerPatchData patch_data
handle the data of the patches of the scheduler
u64 crit_patch_split
splitting limit (if load value > crit_patch_split => patch split)
PatchTree patch_tree
handle the tree structure of the patches
void scheduler_step(bool do_split_merge, bool do_load_balancing)
scheduler step
void set_coord_domain_bound(vectype bmin, vectype bmax)
modify the bounding box of the patch domain
SchedulerPatchList patch_list
handle the list of the patches of the scheduler
std::unordered_set< u64 > owned_patch_id
(owned_patch_id = patch_list.build_local())
std::vector< u64 > add_root_patches(std::vector< shamrock::patch::PatchCoord< 3 > > coords)
add a root patch to the scheduler
u64 crit_patch_merge
merging limit (if load value < crit_patch_merge => patch merge)
void allpush_data(shamrock::patch::PatchDataLayer &pdat)
Push data into the scheduler. The content of pdat has to be the same on each node.
void add_root_patch()
add patch to the scheduler
Handle the patch list of the mpi scheduler.
std::vector< shamrock::patch::Patch > local
contain the list of patch owned by the current node
bool is_load_values_up_to_date
Are patch load values up to date.
std::unordered_map< u64, u64 > id_patch_to_local_idx
id_patch_to_local_idx[patch_id] = index in local patch list
std::vector< shamrock::patch::Patch > global
contain the list of all patches in the simulation
std::unordered_map< u64, u64 > id_patch_to_global_idx
id_patch_to_global_idx[patch_id] = index in global patch list
Represents a collection of objects distributed across patches identified by a u64 id.
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
FieldDescriptor< T > get_field(const std::string &field_name)
Get the field description id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
Store the information related to the size of the simulation box to convert patch integer coordinates ...
Definition SimBox.hpp:35
void set_bounding_box(shammath::CoordRange< T > new_box)
Override the stored bounding box by the one given in new_box.
Definition SimBox.hpp:272
static constexpr u64 max_box_sz
maximal value along an axis for the patch coordinate
Patch Tree : Tree structure organisation for an abstract list of patches Nb : this tree is compatible...
Definition PatchTree.hpp:29
Class to handle PatchData owned by the node.
shamrock::patch::SimulationBoxInfo sim_box
simulation box geometry info
shambase::DistributedData< PatchData > owned_data
map container for patchdata owned by the current node (layout : id_patch,data)
A node that applies a custom function to modify connected edges.
virtual void free_alloc() override
Free allocated memory.
memory manipulation algorithms
namespace for basic c++ utilities
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
header for PatchData related function and declaration
Class to handle the patch list of the mpi scheduler.
This file contains the definition for the stacktrace related functionality.
Patch object that contain generic patch information.
Definition Patch.hpp:33
bool is_err_mode() const
check if a patch is in error mode
Definition Patch.hpp:119
u32 node_owner_id
node rank owner of this patch
Definition Patch.hpp:93
u64 id_patch
unique key that identify the patch
Definition Patch.hpp:86
header file to manage sycl