96 std::shared_ptr<shamrock::patch::PatchDataLayerLayout> pdl_ptr;
113 inline std::shared_ptr<shamrock::patch::PatchDataLayerLayout> get_layout_ptr_old()
const {
125 void init_mpi_required_types();
127 void free_mpi_required_types();
130 const std::shared_ptr<shamrock::patch::PatchDataLayerLayout> &pdl_ptr,
136 std::string dump_status();
138 inline void update_local_load_value(std::function<
u64(shamrock::patch::Patch)> load_function) {
141 p.load_value = load_function(p);
146 template<
class vectype>
147 std::tuple<vectype, vectype> get_box_tranform();
149 template<
class vectype>
150 std::tuple<vectype, vectype> get_box_volume();
152 bool should_resize_box(
bool node_in);
161 template<
class vectype>
164 if (!pdl_old().check_main_field_type<vectype>()) {
165 std::invalid_argument(
166 std::string(
"the main field is not of the correct type to call this function\n")
167 +
"fct called : " + __PRETTY_FUNCTION__
168 +
"current patch data layout : " + pdl_old().get_description_str());
171 patch_data.sim_box.set_bounding_box<vectype>({bmin, bmax});
173 shamlog_debug_ln(
"PatchScheduler",
"box resized to :", bmin, bmax);
185 void make_patch_base_grid(std::array<u32, dim> patch_count);
193 template<
class vectype>
201 void check_patchdata_locality_correctness();
204 void dump_local_patches(std::string filename);
206 std::vector<std::unique_ptr<shamrock::patch::PatchDataLayer>> gather_data(
u32 rank);
231 void sync_build_LB(
bool global_patch_sync,
bool balance_load);
235 return get_sim_box().template get_patch_transform<vec>();
257 template<
class Function>
265 fct(patch_id, cur_p, pdat);
270 template<
class Function>
271 inline void for_each_patch(Function &&fct) {
279 fct(patch_id, cur_p);
284 inline void for_each_global_patch(
285 const std::function<
void(
const shamrock::patch::Patch &)> &fct) {
286 for (
const shamrock::patch::Patch &p :
patch_list.global) {
287 if (!p.is_err_mode()) {
293 inline void for_each_local_patch(
294 const std::function<
void(
const shamrock::patch::Patch &)> &fct) {
295 for (
const shamrock::patch::Patch &p :
patch_list.local) {
296 if (!p.is_err_mode()) {
302 inline void for_each_local_patchdata(
303 const std::function<
void(
const shamrock::patch::Patch &, shamrock::patch::PatchDataLayer &)>
305 for (
const shamrock::patch::Patch &p :
patch_list.local) {
306 if (!p.is_err_mode()) {
312 inline void for_each_local_patch_nonempty(
313 std::function<
void(
const shamrock::patch::Patch &)> fct) {
314 patch_data.for_each_patchdata([&](
u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
315 shamrock::patch::Patch &cur_p
318 if ((!cur_p.
is_err_mode()) && (!pdat.is_empty())) {
324 inline u32 get_patch_rank_owner(
u64 patch_id) {
325 shamrock::patch::Patch &cur_p
330 inline void for_each_patchdata_nonempty(
331 std::function<
void(
const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &)> fct) {
332 patch_data.for_each_patchdata([&](
u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
333 shamrock::patch::Patch &cur_p
336 if ((!cur_p.
is_err_mode()) && (!pdat.is_empty())) {
343 inline shambase::DistributedData<T> map_owned_patchdata(
344 std::function<T(
const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
345 shambase::DistributedData<T> ret;
347 using namespace shamrock::patch;
349 ret.
add_obj(id_patch, fct(cur_p, pdat));
356 inline shambase::DistributedData<T> distrib_data_local_to_all_simple(
357 shambase::DistributedData<T> &src) {
358 using namespace shamrock::patch;
362 return shamalgs::collective::fetch_all_simple<T, Patch>(
369 inline shambase::DistributedData<T> distrib_data_local_to_all_load_store(
370 shambase::DistributedData<T> &src) {
371 using namespace shamrock::patch;
373 return shamalgs::collective::fetch_all_storeload<T, Patch>(
380 inline shambase::DistributedData<T> map_owned_patchdata_fetch_simple(
381 std::function<T(
const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
382 shambase::DistributedData<T> ret;
384 using namespace shamrock::patch;
386 ret.
add_obj(id_patch, fct(cur_p, pdat));
389 return distrib_data_local_to_all_simple(ret);
393 inline shambase::DistributedData<T> map_owned_patchdata_fetch_load_store(
394 std::function<T(
const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
395 shambase::DistributedData<T> ret;
397 using namespace shamrock::patch;
399 ret.
add_obj(id_patch, fct(cur_p, pdat));
402 return distrib_data_local_to_all_load_store(ret);
406 inline shamrock::patch::PatchField<T> map_owned_to_patch_field_simple(
407 std::function<T(
const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
408 return shamrock::patch::PatchField<T>(map_owned_patchdata_fetch_simple(fct));
412 inline shamrock::patch::PatchField<T> map_owned_to_patch_field_load_store(
413 std::function<T(
const shamrock::patch::Patch, shamrock::patch::PatchDataLayer &pdat)> fct) {
414 return shamrock::patch::PatchField<T>(map_owned_patchdata_fetch_load_store(fct));
417 inline u64 get_rank_count() {
419 using namespace shamrock::patch;
422 num_obj += pdat.get_obj_cnt();
428 inline u64 get_total_obj_count() {
430 u64 part_cnt = get_rank_count();
431 return shamalgs::collective::allreduce_sum(part_cnt);
435 inline std::unique_ptr<sycl::buffer<T>> rankgather_field(
u32 field_idx) {
437 std::unique_ptr<sycl::buffer<T>> ret;
439 auto fd = pdl_old().get_field<T>(field_idx);
442 u64 num_obj = get_rank_count();
445 ret = std::make_unique<sycl::buffer<T>>(num_obj * nvar);
447 using namespace shamrock::patch;
451 using namespace shamalgs::memory;
452 using namespace shambase;
454 if (pdat.get_obj_cnt() > 0) {
455 write_with_offset_into(
456 shamsys::instance::get_compute_scheduler().get_queue(),
458 pdat.get_field<T>(field_idx).get_buf(),
460 pdat.get_obj_cnt() * nvar);
462 ptr += pdat.get_obj_cnt() * nvar;
492 template<
class Function,
class Pfield>
493 inline void compute_patch_field(Pfield &field, MPI_Datatype &dtype, Function &&lambda) {
494 field.local_nodes_value.resize(
patch_list.local.size());
498 shamrock::patch::Patch &cur_p =
patch_list.local[idx];
501 field.local_nodes_value[idx] = lambda(
502 shamsys::instance::get_compute_queue(),
508 field.build_global(dtype);
511 inline auto get_node_set_edge_patchdata_layer_refs() {
512 shamrock::solvergraph::NodeSetEdge<shamrock::solvergraph::PatchDataLayerRefs> node_set_edge(
513 [&](shamrock::solvergraph::PatchDataLayerRefs &edge) {
515 using namespace shamrock::patch;
516 for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
521 return std::make_shared<decltype(node_set_edge)>(std::move(node_set_edge));
530 std::vector<u64>
add_root_patches(std::vector<shamrock::patch::PatchCoord<3>> coords);
532 shamrock::patch::SimulationBoxInfo &get_sim_box() {
return patch_data.sim_box; }
534 nlohmann::json serialize_patch_metadata();
537 void split_patches(std::unordered_set<u64> split_rq);
538 void merge_patches(std::unordered_set<u64> merge_rq);
540 void set_patch_pack_values(std::unordered_set<u64> merge_rq);