Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
Solver.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
18
20#include "SolverConfig.hpp"
21#include "shambackends/vec.hpp"
33#include <functional>
34#include <limits>
35#include <memory>
36#include <optional>
37#include <stdexcept>
38#include <variant>
39#include <vector>
40namespace shammodels::sph {
41
42 struct TimestepLog {
43 i32 rank;
44 f64 rate;
45 u64 npart;
46 f64 tcompute;
47
48 inline f64 rate_sum() { return shamalgs::collective::allreduce_sum(rate); }
49
50 inline u64 npart_sum() { return shamalgs::collective::allreduce_sum(npart); }
51
52 inline f64 tcompute_max() { return shamalgs::collective::allreduce_max(tcompute); }
53 };
54
56 bool reach_target_time;
57 bool reach_niter_max;
58 bool reach_max_walltime;
59
60 i32 iter_count;
61 };
62
69 template<class Tvec, template<class> class SPHKernel>
70 class Solver {
71 public:
72 using Tscal = shambase::VecComponent<Tvec>;
73 static constexpr u32 dim = shambase::VectorProperties<Tvec>::dimension;
74 using Kernel = SPHKernel<Tscal>;
75
76 using Config = SolverConfig<Tvec, SPHKernel>;
77
78 using u_morton = typename Config::u_morton;
79
80 static constexpr Tscal Rkern = Kernel::Rkern;
81
82 ShamrockCtx &context;
83 inline PatchScheduler &scheduler() { return shambase::get_check_ref(context.sched); }
84
86
87 Config solver_config;
88 SolverLog solve_logs;
89
91 std::optional<std::function<void(void)>> step_begin_callback;
92 std::optional<std::function<void(void)>> step_end_callback;
93 };
94 std::vector<SolverStepCallback> timestep_callbacks{};
95
96 inline void init_required_fields() { solver_config.set_layout(context.get_pdl_write()); }
97
98 // serial patch tree control
99 void gen_serial_patch_tree();
100 inline void reset_serial_patch_tree() { storage.serial_patch_tree.reset(); }
101
102 // interface_control
103 using GhostHandle = sph::BasicSPHGhostHandler<Tvec>;
104 using GhostHandleCache = typename GhostHandle::CacheMap;
105
106 inline void gen_ghost_handler(Tscal time_val) {
107
108 using CfgClass = sph::BasicSPHGhostHandlerConfig<Tvec>;
109 using BCConfig = typename CfgClass::Variant;
110
111 using BCFree = typename CfgClass::Free;
112 using BCPeriodic = typename CfgClass::Periodic;
113 using BCShearingPeriodic = typename CfgClass::ShearingPeriodic;
114
115 using SolverConfigBC = typename Config::BCConfig;
116 using SolverBCFree = typename SolverConfigBC::Free;
117 using SolverBCPeriodic = typename SolverConfigBC::Periodic;
118 using SolverBCShearingPeriodic = typename SolverConfigBC::ShearingPeriodic;
119
120 // boundary condition selections
121 if (SolverBCFree *c
122 = std::get_if<SolverBCFree>(&solver_config.boundary_config.config)) {
123 storage.ghost_handler.set(
124 GhostHandle{
125 scheduler(),
126 BCFree{},
127 storage.patch_rank_owner,
128 storage.xyzh_ghost_layout});
129 } else if (
130 SolverBCPeriodic *c
131 = std::get_if<SolverBCPeriodic>(&solver_config.boundary_config.config)) {
132 storage.ghost_handler.set(
133 GhostHandle{
134 scheduler(),
135 BCPeriodic{},
136 storage.patch_rank_owner,
137 storage.xyzh_ghost_layout});
138 } else if (
139 SolverBCShearingPeriodic *c
140 = std::get_if<SolverBCShearingPeriodic>(&solver_config.boundary_config.config)) {
141 storage.ghost_handler.set(
142 GhostHandle{
143 scheduler(),
144 BCShearingPeriodic{
145 c->shear_base, c->shear_dir, c->shear_speed * time_val, c->shear_speed},
146 storage.patch_rank_owner,
147 storage.xyzh_ghost_layout});
148 }
149 }
150 inline void reset_ghost_handler() { storage.ghost_handler.reset(); }
151
153 void build_ghost_cache();
155 void clear_ghost_cache();
156
159
160 // trees
161 using RTree = typename Config::RTree;
166
170 void reset_presteps_rint();
171
176
178 void sph_prestep(Tscal time_val, Tscal dt);
179
181 void apply_position_boundary(Tscal time_val);
182
184 void update_artificial_viscosity(Tscal dt);
185
187 void init_ghost_layout();
188
193
195 void compute_eos_fields();
196
198 void reset_eos_fields();
199
201 void prepare_corrector();
203 void update_derivs(Tscal dt_hydro);
210 bool apply_corrector(Tscal dt, u64 Npart_all);
211
214
215 Solver(ShamrockCtx &context) : context(context) {}
216
218 void init_solver_graph();
219
221 void vtk_do_dump(std::string filename, bool add_patch_world_id);
222
223 void set_debug_dump(bool _do_debug_dump, std::string _debug_dump_filename) {
224 solver_config.set_debug_dump(_do_debug_dump, _debug_dump_filename);
225 }
226
227 inline void print_timestep_logs() {
228 if (shamcomm::world_rank() == 0) {
229 logger::info_ln("SPH", "iteration since start :", solve_logs.get_iteration_count());
230 logger::info_ln("SPH", "time since start :", shambase::details::get_wtime(), "(s)");
231 }
232 }
233
235 TimestepLog evolve_once();
236
238 Tscal evolve_once_time_expl(Tscal t_current, Tscal dt_input) {
239 solver_config.set_time(t_current);
240 solver_config.set_next_dt(dt_input);
241 evolve_once();
242 return solver_config.get_dt_sph();
243 }
244
245 inline EvolveUntilResults evolve_until(
246 Tscal target_time, i32 niter_max, f64 max_walltime = -1) {
247
248 const bool niter_limit_active = (niter_max >= 0);
249 const bool walltime_limit_active = (max_walltime >= 0);
250
251 if (shamcomm::world_rank() == 0) {
253 "SPH",
254 shambase::format(
255 "evolve_until (target_time = {:.2f}s, niter_max = {}, max_walltime = "
256 "{:.2f}s)",
257 target_time,
258 niter_max,
259 max_walltime));
260 }
261
262 auto synced_wtime = [&]() -> f64 {
263 if (walltime_limit_active) {
264 return shamalgs::collective::allreduce_max(shambase::details::get_wtime());
265 }
266 return 0;
267 };
268
269 auto step = [&]() {
270 Tscal dt = solver_config.get_dt_sph();
271 Tscal t = solver_config.get_time();
272
273 if (t > target_time) {
275 "the target time is higher than the current time");
276 }
277
278 if (t + dt > target_time) {
279 solver_config.set_next_dt(target_time - t);
280 }
281 evolve_once();
282 };
283
284 f64 start_wall_time = (walltime_limit_active) ? synced_wtime() : 0;
285
286 i32 next_walltime_check_iter
287 = walltime_limit_active ? 1 : std::numeric_limits<i32>::max();
288
289 i32 iter_count = 0;
290
291 while (solver_config.get_time() < target_time) {
292 step();
293 iter_count++;
294
295 // if the iteration count is greater than the maximum iteration count
296 if (niter_limit_active && iter_count >= niter_max) {
297 if (shamcomm::world_rank() == 0) {
298 logger::info_ln(
299 "SPH", "stopping evolve until because of niter =", iter_count);
300 }
301 return {
302 .reach_target_time = false,
303 .reach_niter_max = true,
304 .reach_max_walltime = false,
305 .iter_count = iter_count,
306 };
307 }
308
309 // if walltime limit is active and the next walltime check is due
310 if (walltime_limit_active && iter_count >= next_walltime_check_iter) {
311 f64 global_walltime = synced_wtime();
312
313 // if the global walltime is greater than the max walltime
314 if (global_walltime >= max_walltime) {
315 if (shamcomm::world_rank() == 0) {
316 logger::info_ln(
317 "SPH",
318 shambase::format(
319 "stopping evolve until because of "
320 "max_walltime = {:.2f}s > {:.2f}s",
321 global_walltime,
322 max_walltime));
323 }
324 return {
325 .reach_target_time = false,
326 .reach_niter_max = false,
327 .reach_max_walltime = true,
328 .iter_count = iter_count,
329 };
330 }
331
332 f64 sec_per_iter
333 = (global_walltime - start_wall_time) / static_cast<f64>(iter_count);
334
335 auto get_remaining_iters = [&](f64 delta_walltime, f64 factor) -> i32 {
336 if (sec_per_iter > 0) {
337 f64 tmp = factor * delta_walltime / sec_per_iter;
338 if (tmp > std::numeric_limits<i32>::max()) {
339 return std::numeric_limits<i32>::max();
340 }
341 return static_cast<i32>(tmp);
342 }
343 return 1000; // default to 1000 iterations if sec_per_iter is 0
344 };
345
346 i32 iters_to_limit = get_remaining_iters(max_walltime - global_walltime, 0.25);
347 i32 iters_to_next_check = iters_to_limit;
348
349 next_walltime_check_iter = iter_count + std::max(1, iters_to_next_check);
350
351 if (shamcomm::world_rank() == 0) {
352 logger::info_ln(
353 "SPH",
354 shambase::format(
355 "next walltime check in {:.2f}s (niter = {}) global walltime = "
356 "{:.2f}s (max_walltime = {:.2f}s)",
357 iters_to_next_check * sec_per_iter,
358 iters_to_next_check,
359 global_walltime,
360 max_walltime));
361 }
362 }
363 }
364
365 print_timestep_logs();
366
367 return {
368 .reach_target_time = true,
369 .reach_niter_max = false,
370 .reach_max_walltime = false,
371 .iter_count = iter_count,
372 };
373 }
374 };
375
376} // namespace shammodels::sph
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
The MPI scheduler.
void reset_presteps_rint()
Resets tree radius interval field.
Definition Solver.cpp:1196
void reset_merge_ghosts_fields()
Resets merged ghost field data.
Definition Solver.cpp:1473
void update_sync_load_values()
Updates load balancing values and synchronizes patch ownership.
Definition Solver.cpp:1565
bool apply_corrector(Tscal dt, u64 Npart_all)
Definition Solver.cpp:1560
void merge_position_ghost()
Merges ghost particle positions from neighboring patches.
Definition Solver.cpp:843
void reset_eos_fields()
Frees memory allocated for EOS fields.
Definition Solver.cpp:1499
void prepare_corrector()
Saves old derivative fields for predictor-corrector integration.
Definition Solver.cpp:1505
void build_ghost_cache()
Builds ghost particle interface cache for inter-patch communication.
Definition Solver.cpp:821
void update_artificial_viscosity(Tscal dt)
Updates artificial viscosity coefficients for shock capturing.
Definition Solver.cpp:1482
TimestepLog evolve_once()
Performs one complete SPH timestep evolution.
Definition Solver.cpp:1623
void vtk_do_dump(std::string filename, bool add_patch_world_id)
Writes VTK dump file for visualization.
Definition Solver.cpp:598
void update_derivs(Tscal dt_hydro)
Updates time derivatives and applies external forces.
Definition Solver.cpp:1550
void build_merged_pos_trees()
Builds spatial BVH trees for merged positions including ghosts.
Definition Solver.cpp:886
void clear_merged_pos_trees()
Clears merged position trees to free memory.
Definition Solver.cpp:891
void init_solver_graph()
Initializes the solver graph for computation pipeline.
Definition Solver.cpp:112
void sph_prestep(Tscal time_val, Tscal dt)
Performs pre-step operations for SPH timestep.
Definition Solver.cpp:897
void compute_presteps_rint()
Computes maximum smoothing length in tree nodes for neighbor search.
Definition Solver.cpp:1159
void compute_eos_fields()
Computes equation of state fields (pressure, sound speed).
Definition Solver.cpp:1493
void apply_position_boundary(Tscal time_val)
Applies position-based boundary conditions.
Definition Solver.cpp:775
void reset_neighbors_cache()
Resets neighbor cache.
Definition Solver.cpp:1226
Tscal evolve_once_time_expl(Tscal t_current, Tscal dt_input)
Evolves system by one explicit timestep with specified time and dt.
Definition Solver.hpp:238
void communicate_merge_ghosts_fields()
Communicates and merges ghost particle fields across processes.
Definition Solver.cpp:1231
void clear_ghost_cache()
Clears ghost particle cache to free memory.
Definition Solver.cpp:837
void init_ghost_layout()
Initializes data layout for ghost particle fields.
Definition Solver.cpp:1144
void start_neighbors_cache()
Builds neighbor particle cache for SPH calculations.
Definition Solver.cpp:1201
This header file contains utility functions related to exception handling in the code.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
ExcptTypes make_except_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Create an exception with a message and a location.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for the sph model
void info_ln(std::string module_name, Types... var2)
Prints a log message with multiple arguments followed by a newline.
Definition logs.hpp:133
f64 get_wtime()
Returns the current wall clock time in seconds.
The configuration for a sph solver.
BCConfig boundary_config
Boundary condition configuration.
u32 u_morton
The type of the Morton code for the tree.
BCConfig< Tvec > BCConfig
Configuration of the boundary conditions.
Class holding the logs of the solver /todo add a variable to keep only a definite number of steps in ...
Definition SolverLog.hpp:33