Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
DataInserterUtility.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
23#include <mpi.h>
24
25namespace shamrock {
26
38 PatchScheduler &sched;
39
40 public:
46 DataInserterUtility(PatchScheduler &sched) : sched(sched) {}
47
48 inline void balance_load(std::function<void(void)> load_balance_update) {
49 // it seems that we need multiple runs to converge the load balance
51 logger::info_ln(
52 "DataInserterUtility", "---------------------------------------------"));
53
54 u64 npatch_last = sched.patch_list.global.size();
55
56 u32 max_runs = 30;
57
58 int i = 0;
59 for (; i < max_runs; i++) {
60 if (shamcomm::world_rank() == 0) {
61 logger::info_ln("DataInserterUtility", "Compute load ...");
62 }
63
64 load_balance_update();
65
66 shamcomm::mpi::Barrier(MPI_COMM_WORLD);
67
68 if (shamcomm::world_rank() == 0) {
69 logger::info_ln("DataInserterUtility", "run scheduler step ...");
70 }
71
72 sched.scheduler_step(false, false);
73 sched.scheduler_step(true, true);
74
75 u64 npatch_new = sched.patch_list.global.size();
76 if (npatch_new != npatch_last) {
77 npatch_last = npatch_new;
78 } else {
79 break;
80 }
81 }
82
84 logger::info_ln(
85 "DataInserterUtility",
86 "patch count stable after",
87 i + 1,
88 "runs npatch =",
89 npatch_last));
90
92 logger::info_ln(
93 "DataInserterUtility", "---------------------------------------------"));
94 }
95
// Pushes the objects held in `pdat_ins` into the scheduler. The local batch is
// appended to one local patch; objects are then reattributed to the patches that
// should own them, using the `main_field_name` field as the position field.
// Patch-data locality is checked, and balance_load(load_balance_update) is run.
// Returns the number of pushed objects summed over all ranks.
// NOTE(review): this extract is missing Doxygen lines 112-113 (function name and
// first parameter), 135 and 147. Per the symbol index the signature is
// `u64 push_patch_data(shamrock::patch::PatchDataLayer &pdat_ins,
//  std::string main_field_name, u32 split_threshold,
//  std::function<void(void)> load_balance_update)`.
111 template<class Tvec>
114 std::string main_field_name,
115 u32 split_threshold,
116 std::function<void(void)> load_balance_update) {
117 using namespace shamrock::patch;
118
119 u64 pdat_ob_cnt = pdat_ins.get_obj_cnt();
120
// Global object count over all ranks; logged here and returned at the end.
121 u64 sum_push = shamalgs::collective::allreduce_sum(pdat_ob_cnt);
122 if (shamcomm::world_rank() == 0) {
123 logger::info_ln("DataInserterUtility", "pushing data in scheduler, N =", sum_push);
124 }
125
// Only batches whose LOCAL object count (not the global sum) is below
// split_threshold are supported. The whole local batch goes into the first
// local patch visited; the reattribution step below moves objects to the
// patches that should own them.
// NOTE(review): if this rank owns no local patch, the callback never runs and
// `pdat_ins` is silently not inserted (should_insert stays true). Confirm that
// callers only pass data on ranks that own a patch, or add a check after the loop.
126 if (pdat_ob_cnt < split_threshold) {
127 bool should_insert = true;
128 sched.for_each_local_patchdata([&](const Patch &p, PatchDataLayer &pdat) {
129 if (should_insert) {
130 pdat.insert_elements(pdat_ins);
131 should_insert = false; // insert into the first visited patch only, so objects are not duplicated
132 }
133 });
134 } else {
// NOTE(review): the call receiving this message (listing line 135) is missing
// from this extract; presumably shambase::throw_unimplemented - confirm.
136 "Not implemented yet please keep the obj count to be "
137 "inserted below the split_threshold, sorrrrrry ...");
138 }
139
140 if (shamcomm::world_rank() == 0) {
141 logger::info_ln("DataInserterUtility", "reattributing data ...");
142 }
143
// Time the reattribution step for the log message below.
144 shambase::Timer treatrib;
145 treatrib.start();
146 // move data into the correct patches
// NOTE(review): listing line 147, which declares `sptree` (presumably a
// SerialPatchTree<Tvec>, per the reatribute_patch_objects signature), is missing
// from this extract.
148 ReattributeDataUtility reatrib(sched);
149 sptree.attach_buf();
150 reatrib.reatribute_patch_objects(sptree, main_field_name);
151 sched.check_patchdata_locality_correctness();
152
153 treatrib.end();
154 if (shamcomm::world_rank() == 0) {
155 logger::info_ln(
156 "DataInserterUtility", "reattributing data done in ", treatrib.get_time_str());
157 }
// Synchronize all ranks before entering the load-balancing loop.
158 shamcomm::mpi::Barrier(MPI_COMM_WORLD);
159
160 balance_load(load_balance_update);
161
162 return sum_push;
163 }
164 };
165
166} // namespace shamrock
MPI scheduler.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
The MPI scheduler.
void scheduler_step(bool do_split_merge, bool do_load_balancing)
scheduler step
SchedulerPatchList patch_list
Handles the list of patches of the scheduler.
std::vector< shamrock::patch::Patch > global
Contains the list of all patches in the simulation.
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:96
std::string get_time_str() const
Converts the stored nanosecond time to a string representation.
Definition time.hpp:117
void end()
Stops the timer and stores the elapsed time in nanoseconds.
Definition time.hpp:111
void start()
Starts the timer.
Definition time.hpp:106
Class to insert data in the PatchScheduler.
u64 push_patch_data(shamrock::patch::PatchDataLayer &pdat_ins, std::string main_field_name, u32 split_threshold, std::function< void(void)> load_balance_update)
Pushes data into the scheduler.
DataInserterUtility(PatchScheduler &sched)
Constructor.
Utility class used to move the objects between patches.
void reatribute_patch_objects(SerialPatchTree< T > &sptree, std::string position_field)
Reattribute objects based on a given position field.
PatchDataLayer container class; its layout is described in patchdata_layout.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for the main framework
Definition __init__.py:1
Patch object that contains generic patch information.
Definition Patch.hpp:33
Functions related to the MPI communicator.
#define ON_RANK_0(x)
Macro to execute code only on rank 0.
Definition worldInfo.hpp:73