Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
PatchScheduler.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
22
25#include "shambase/time.hpp"
30#include <nlohmann/json.hpp>
31#include <unordered_set>
32#include <fstream>
33#include <functional>
34#include <memory>
35#include <stdexcept>
36#include <tuple>
37#include <vector>
38// #include "shamrock/scheduler/SerialPatchTree.hpp"
41// #include "shamrock/legacy/patch/patchdata_buffer.hpp"
43#include "shambackends/math.hpp"
51
53 u64 split_load_value = 0_u64;
54 u64 merge_load_value = 0_u64;
55};
56
63inline void to_json(nlohmann::json &j, const PatchSchedulerConfig &p) {
64 j = nlohmann::json{
65 {"split_load_value", p.split_load_value},
66 {"merge_load_value", p.merge_load_value},
67 };
68}
69
76inline void from_json(const nlohmann::json &j, PatchSchedulerConfig &p) {
77 j.at("split_load_value").get_to<u64>(p.split_load_value);
78 j.at("merge_load_value").get_to<u64>(p.merge_load_value);
79}
80
85class PatchScheduler {
86
88
89 public:
90 static constexpr u64 max_axis_patch_coord = LoadBalancer::max_box_sz;
91 static constexpr u64 max_axis_patch_coord_length = LoadBalancer::max_box_sz + 1;
92
93 using PatchTree = shamrock::scheduler::PatchTree;
94 using SchedulerPatchData = shamrock::scheduler::SchedulerPatchData;
95
96 std::shared_ptr<shamrock::patch::PatchDataLayerLayout> pdl_ptr;
97
100
102 SchedulerPatchData patch_data;
103 PatchTree patch_tree;
104
105 // using unordered set is not an issue since we use the find command after
106 std::unordered_set<u64> owned_patch_id;
108
109 inline shamrock::patch::PatchDataLayerLayout &pdl_old() {
110 return shambase::get_check_ref(pdl_ptr);
111 }
112
113 inline std::shared_ptr<shamrock::patch::PatchDataLayerLayout> get_layout_ptr_old() const {
114 return pdl_ptr;
115 }
116
123 void scheduler_step(bool do_split_merge, bool do_load_balancing);
124
125 void init_mpi_required_types();
126
127 void free_mpi_required_types();
128
129 PatchScheduler(
130 const std::shared_ptr<shamrock::patch::PatchDataLayerLayout> &pdl_ptr,
131 u64 crit_split,
132 u64 crit_merge);
133
134 ~PatchScheduler();
135
136 std::string dump_status();
137
138 inline void update_local_load_value(std::function<u64(shamrock::patch::Patch)> load_function) {
139 for (u64 id : owned_patch_id) {
140 shamrock::patch::Patch &p = patch_list.local[patch_list.id_patch_to_local_idx[id]];
141 p.load_value = load_function(p);
142 }
143 patch_list.is_load_values_up_to_date = true;
144 }
145
/// @brief Box transform expressed with vector type vectype (defined out of line).
/// NOTE(review): the name keeps its historical spelling ("tranform"); callers depend on it.
template<class vectype>
std::tuple<vectype, vectype> get_box_tranform();

/// @brief Returns a pair of vectype values describing the box volume (defined out of line).
template<class vectype>
std::tuple<vectype, vectype> get_box_volume();

/// @brief Whether the simulation box should be resized (defined out of line).
/// @param node_in this rank's input to the decision — TODO confirm meaning against the definition
bool should_resize_box(bool node_in);
153
161 template<class vectype>
162 void set_coord_domain_bound(vectype bmin, vectype bmax) {
163
164 if (!pdl_old().check_main_field_type<vectype>()) {
165 std::invalid_argument(
166 std::string("the main field is not of the correct type to call this function\n")
167 + "fct called : " + __PRETTY_FUNCTION__
168 + "current patch data layout : " + pdl_old().get_description_str());
169 }
170
171 patch_data.sim_box.set_bounding_box<vectype>({bmin, bmax});
172
173 shamlog_debug_ln("PatchScheduler", "box resized to :", bmin, bmax);
174 }
175
183
184 template<u32 dim>
185 void make_patch_base_grid(std::array<u32, dim> patch_count);
186
/**
 * @brief modify the bounding box of the patch domain
 *
 * Convenience overload taking both corners packed in a tuple. It forwards to
 * set_coord_domain_bound(vectype bmin, vectype bmax); previously it unpacked
 * the tuple and then did nothing with it.
 *
 * @param box tuple (bmin, bmax)
 */
template<class vectype>
void set_coord_domain_bound(std::tuple<vectype, vectype> box) {
    auto [a, b] = box;
    set_coord_domain_bound<vectype>(a, b);
}
198
199 std::string format_patch_coord(shamrock::patch::Patch p);
200
201 void check_patchdata_locality_correctness();
202
203 [[deprecated]]
204 void dump_local_patches(std::string filename);
205
206 std::vector<std::unique_ptr<shamrock::patch::PatchDataLayer>> gather_data(u32 rank);
207
216 //[[deprecated]]
217 // inline u64 add_patch(shamrock::patch::Patch p, shamrock::patch::PatchData && pdat){
218 // p.id_patch = patch_list._next_patch_id;
219 // patch_list._next_patch_id ++;
220 //
221 // patch_list.global.push_back(p);
222 //
223 // patch_data.owned_data.insert({p.id_patch , pdat});
224 //
225 // return p.id_patch;
226 //}
227
228 void add_root_patch();
229
230 [[deprecated]]
231 void sync_build_LB(bool global_patch_sync, bool balance_load);
232
233 template<class vec>
234 inline shamrock::patch::PatchCoordTransform<vec> get_patch_transform() {
235 return get_sim_box().template get_patch_transform<vec>();
236 }
237
238 // template<class vec>
239 // inline SerialPatchTree<vec> make_serial_ptree(){
240 // return SerialPatchTree<vec>(patch_tree, get_patch_transform<vec>());
241 // }
242
257 template<class Function>
258 inline void for_each_patch_data(Function &&fct) {
259
260 patch_data.for_each_patchdata([&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
262 = patch_list.global[patch_list.id_patch_to_global_idx[patch_id]];
263
264 if (!cur_p.is_err_mode()) {
265 fct(patch_id, cur_p, pdat);
266 }
267 });
268 }
269
270 template<class Function>
271 inline void for_each_patch(Function &&fct) {
272
273 patch_data.for_each_patchdata([&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
276
277 // TODO should feed the sycl queue to the lambda
278 if (!cur_p.is_err_mode()) {
279 fct(patch_id, cur_p);
280 }
281 });
282 }
283
284 inline void for_each_global_patch(
285 const std::function<void(const shamrock::patch::Patch &)> &fct) {
286 for (const shamrock::patch::Patch &p : patch_list.global) {
287 if (!p.is_err_mode()) {
288 fct(p);
289 }
290 }
291 }
292
293 inline void for_each_local_patch(
294 const std::function<void(const shamrock::patch::Patch &)> &fct) {
295 for (const shamrock::patch::Patch &p : patch_list.local) {
296 if (!p.is_err_mode()) {
297 fct(p);
298 }
299 }
300 }
301
302 inline void for_each_local_patchdata(
303 const std::function<void(const shamrock::patch::Patch &, shamrock::patch::PatchDataLayer &)>
304 &fct) {
305 for (const shamrock::patch::Patch &p : patch_list.local) {
306 if (!p.is_err_mode()) {
307 fct(p, patch_data.get_pdat(p.id_patch));
308 }
309 }
310 }
311
312 inline void for_each_local_patch_nonempty(
313 std::function<void(const shamrock::patch::Patch &)> fct) {
314 patch_data.for_each_patchdata([&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
315 shamrock::patch::Patch &cur_p
316 = patch_list.global[patch_list.id_patch_to_global_idx.at(patch_id)];
317
318 if ((!cur_p.is_err_mode()) && (!pdat.is_empty())) {
319 fct(cur_p);
320 }
321 });
322 }
323
324 inline u32 get_patch_rank_owner(u64 patch_id) {
325 shamrock::patch::Patch &cur_p
326 = patch_list.global[patch_list.id_patch_to_global_idx.at(patch_id)];
327 return cur_p.node_owner_id;
328 }
329
330 inline void for_each_patchdata_nonempty(
331 std::function<void(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &)> fct) {
332 patch_data.for_each_patchdata([&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
333 shamrock::patch::Patch &cur_p
334 = patch_list.global[patch_list.id_patch_to_global_idx.at(patch_id)];
335
336 if ((!cur_p.is_err_mode()) && (!pdat.is_empty())) {
337 fct(cur_p, pdat);
338 }
339 });
340 }
341
342 template<class T>
343 inline shambase::DistributedData<T> map_owned_patchdata(
344 std::function<T(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
345 shambase::DistributedData<T> ret;
346
347 using namespace shamrock::patch;
348 for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
349 ret.add_obj(id_patch, fct(cur_p, pdat));
350 });
351
352 return ret;
353 }
354
355 template<class T>
356 inline shambase::DistributedData<T> distrib_data_local_to_all_simple(
357 shambase::DistributedData<T> &src) {
358 using namespace shamrock::patch;
359
360 // TODO : after a split the scheduler patch list state does not match global =
361 // allgather(local) but here it is implicitely assumed, that's ... bad
362 return shamalgs::collective::fetch_all_simple<T, Patch>(
363 src, patch_list.local, patch_list.global, [](Patch p) {
364 return p.id_patch;
365 });
366 }
367
368 template<class T>
369 inline shambase::DistributedData<T> distrib_data_local_to_all_load_store(
370 shambase::DistributedData<T> &src) {
371 using namespace shamrock::patch;
372
373 return shamalgs::collective::fetch_all_storeload<T, Patch>(
374 src, patch_list.local, patch_list.global, [](Patch p) {
375 return p.id_patch;
376 });
377 }
378
379 template<class T>
380 inline shambase::DistributedData<T> map_owned_patchdata_fetch_simple(
381 std::function<T(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
382 shambase::DistributedData<T> ret;
383
384 using namespace shamrock::patch;
385 for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
386 ret.add_obj(id_patch, fct(cur_p, pdat));
387 });
388
389 return distrib_data_local_to_all_simple(ret);
390 }
391
392 template<class T>
393 inline shambase::DistributedData<T> map_owned_patchdata_fetch_load_store(
394 std::function<T(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
395 shambase::DistributedData<T> ret;
396
397 using namespace shamrock::patch;
398 for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
399 ret.add_obj(id_patch, fct(cur_p, pdat));
400 });
401
402 return distrib_data_local_to_all_load_store(ret);
403 }
404
405 template<class T>
406 inline shamrock::patch::PatchField<T> map_owned_to_patch_field_simple(
407 std::function<T(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
408 return shamrock::patch::PatchField<T>(map_owned_patchdata_fetch_simple(fct));
409 }
410
411 template<class T>
412 inline shamrock::patch::PatchField<T> map_owned_to_patch_field_load_store(
413 std::function<T(const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
414 return shamrock::patch::PatchField<T>(map_owned_patchdata_fetch_load_store(fct));
415 }
416
417 inline u64 get_rank_count() {
418 StackEntry stack_loc{};
419 using namespace shamrock::patch;
420 u64 num_obj = 0; // TODO get_rank_count() in scheduler
421 for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
422 num_obj += pdat.get_obj_cnt();
423 });
424
425 return num_obj;
426 }
427
428 inline u64 get_total_obj_count() {
429 StackEntry stack_loc{};
430 u64 part_cnt = get_rank_count();
431 return shamalgs::collective::allreduce_sum(part_cnt);
432 }
433
434 template<class T>
435 inline std::unique_ptr<sycl::buffer<T>> rankgather_field(u32 field_idx) {
436 StackEntry stack_loc{};
437 std::unique_ptr<sycl::buffer<T>> ret;
438
439 auto fd = pdl_old().get_field<T>(field_idx);
440 u64 nvar = fd.nvar;
441
442 u64 num_obj = get_rank_count();
443
444 if (num_obj > 0) {
445 ret = std::make_unique<sycl::buffer<T>>(num_obj * nvar);
446
447 using namespace shamrock::patch;
448
449 u64 ptr = 0; // TODO accumulate_field() in scheduler ?
450 for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
451 using namespace shamalgs::memory;
452 using namespace shambase;
453
454 if (pdat.get_obj_cnt() > 0) {
455 write_with_offset_into(
456 shamsys::instance::get_compute_scheduler().get_queue(),
457 get_check_ref(ret),
458 pdat.get_field<T>(field_idx).get_buf(),
459 ptr,
460 pdat.get_obj_cnt() * nvar);
461
462 ptr += pdat.get_obj_cnt() * nvar;
463 }
464 });
465 }
466
467 return ret;
468 }
469
470 // template<class Function, class Pfield>
471 // inline void compute_patch_field(Pfield & field, MPI_Datatype & dtype , Function && lambda){
472 // field.local_nodes_value.resize(patch_list.local.size());
473 //
474 //
475 //
476 // for (u64 idx = 0; idx < patch_list.local.size(); idx++) {
477 //
478 // Patch &cur_p = patch_list.local[idx];
479 //
480 // PatchDataBuffer pdatbuf =
481 // attach_to_patchData(patch_data.owned_data.at(cur_p.id_patch));
482 //
483 // field.local_nodes_value[idx] =
484 // lambda(shamsys::instance::get_compute_queue(),cur_p,pdatbuf);
485 //
486 // }
487 //
488 // field.build_global(dtype);
489 //
490 // }
491
492 template<class Function, class Pfield>
493 inline void compute_patch_field(Pfield &field, MPI_Datatype &dtype, Function &&lambda) {
494 field.local_nodes_value.resize(patch_list.local.size());
495
496 for (u64 idx = 0; idx < patch_list.local.size(); idx++) {
497
498 shamrock::patch::Patch &cur_p = patch_list.local[idx];
499
500 if (!cur_p.is_err_mode()) {
501 field.local_nodes_value[idx] = lambda(
502 shamsys::instance::get_compute_queue(),
503 cur_p,
504 patch_data.owned_data.get(cur_p.id_patch));
505 }
506 }
507
508 field.build_global(dtype);
509 }
510
511 inline auto get_node_set_edge_patchdata_layer_refs() {
512 shamrock::solvergraph::NodeSetEdge<shamrock::solvergraph::PatchDataLayerRefs> node_set_edge(
513 [&](shamrock::solvergraph::PatchDataLayerRefs &edge) {
514 edge.free_alloc();
515 using namespace shamrock::patch;
516 for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
517 edge.patchdatas.add_obj(cur_p.id_patch, std::ref(pdat));
518 });
519 });
520
521 return std::make_shared<decltype(node_set_edge)>(std::move(node_set_edge));
522 };
523
530 std::vector<u64> add_root_patches(std::vector<shamrock::patch::PatchCoord<3>> coords);
531
532 shamrock::patch::SimulationBoxInfo &get_sim_box() { return patch_data.sim_box; }
533
534 nlohmann::json serialize_patch_metadata();
535
536 private:
537 void split_patches(std::unordered_set<u64> split_rq);
538 void merge_patches(std::unordered_set<u64> merge_rq);
539
540 void set_patch_pack_values(std::unordered_set<u64> merge_rq);
541};
Function to run load balancing with the Hilbert curve.
Node that applies a custom function to modify connected edges.
Defines the PatchDataLayerRefs class for managing distributed references to patch data layers.
void to_json(nlohmann::json &j, const PatchSchedulerConfig &p)
Converts a PatchSchedulerConfig object to a JSON object.
void from_json(const nlohmann::json &j, PatchSchedulerConfig &p)
Deserializes a PatchSchedulerConfig object from a JSON object.
Header file for the patch struct and related function.
PatchData handling.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
void set_coord_domain_bound(std::tuple< vectype, vectype > box)
modify the bounding box of the patch domain
void for_each_patch_data(Function &&fct)
for-each macro over patchdata; example usage
SchedulerPatchData patch_data
handle the data of the patches of the scheduler
u64 crit_patch_split
splitting limit (if load value > crit_patch_split => patch split)
PatchTree patch_tree
handle the tree structure of the patches
void scheduler_step(bool do_split_merge, bool do_load_balancing)
scheduler step
void set_coord_domain_bound(vectype bmin, vectype bmax)
modify the bounding box of the patch domain
SchedulerPatchList patch_list
handle the list of the patches of the scheduler
std::unordered_set< u64 > owned_patch_id
(owned_patch_id = patch_list.build_local())
std::vector< u64 > add_root_patches(std::vector< shamrock::patch::PatchCoord< 3 > > coords)
add a root patch to the scheduler
u64 crit_patch_merge
merging limit (if load value < crit_patch_merge => patch merge)
void allpush_data(shamrock::patch::PatchDataLayer &pdat)
push data in the scheduler. The content of pdat has to be the same for each node
void add_root_patch()
add patch to the scheduler
Handle the patch list of the mpi scheduler.
std::vector< shamrock::patch::Patch > global
contain the list of all patches in the simulation
std::unordered_map< u64, u64 > id_patch_to_global_idx
id_patch_to_global_idx[patch_id] = index in global patch list
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
PatchDataLayer container class, the layout is described in patchdata_layout.
Patch Tree: tree-structure organisation for an abstract list of patches. NB: this tree is compatible...
Definition PatchTree.hpp:29
Class to handle PatchData owned by the node.
virtual void free_alloc() override
Free allocated memory.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
header for PatchData related function and declaration
Class to handle the patch list of the mpi scheduler.
This file contains the definition for the stacktrace related functionality.
shambase::details::BasicStackEntry StackEntry
Alias for shambase::details::BasicStackEntry.
Patch object that contain generic patch information.
Definition Patch.hpp:33
bool is_err_mode() const
check if a patch is in error mode
Definition Patch.hpp:119
u32 node_owner_id
node rank owner of this patch
Definition Patch.hpp:93
u64 id_patch
unique key that identify the patch
Definition Patch.hpp:86
header file to manage sycl