Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
Solver.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
20#include "SolverConfig.hpp"
21#include "shambackends/vec.hpp"
33#include <functional>
34#include <memory>
35#include <optional>
36#include <stdexcept>
37#include <variant>
38#include <vector>
39namespace shammodels::sph {
40
41 struct TimestepLog {
42 i32 rank;
43 f64 rate;
44 u64 npart;
45 f64 tcompute;
46
47 inline f64 rate_sum() { return shamalgs::collective::allreduce_sum(rate); }
48
49 inline u64 npart_sum() { return shamalgs::collective::allreduce_sum(npart); }
50
51 inline f64 tcompute_max() { return shamalgs::collective::allreduce_max(tcompute); }
52 };
53
60 template<class Tvec, template<class> class SPHKernel>
61 class Solver {
62 public:
63 using Tscal = shambase::VecComponent<Tvec>;
65 using Kernel = SPHKernel<Tscal>;
66
68
69 using u_morton = typename Config::u_morton;
70
71 static constexpr Tscal Rkern = Kernel::Rkern;
72
73 ShamrockCtx &context;
74 inline PatchScheduler &scheduler() { return shambase::get_check_ref(context.sched); }
75
77
78 Config solver_config;
79 SolverLog solve_logs;
80
82 std::optional<std::function<void(void)>> step_begin_callback;
83 std::optional<std::function<void(void)>> step_end_callback;
84 };
85 std::vector<SolverStepCallback> timestep_callbacks{};
86
87 inline void init_required_fields() { solver_config.set_layout(context.get_pdl_write()); }
88
89 // serial patch tree control
90 void gen_serial_patch_tree();
91 inline void reset_serial_patch_tree() { storage.serial_patch_tree.reset(); }
92
93 // interface_control
94 using GhostHandle = sph::BasicSPHGhostHandler<Tvec>;
95 using GhostHandleCache = typename GhostHandle::CacheMap;
96
97 inline void gen_ghost_handler(Tscal time_val) {
98
99 using CfgClass = sph::BasicSPHGhostHandlerConfig<Tvec>;
100 using BCConfig = typename CfgClass::Variant;
101
102 using BCFree = typename CfgClass::Free;
103 using BCPeriodic = typename CfgClass::Periodic;
104 using BCShearingPeriodic = typename CfgClass::ShearingPeriodic;
105
106 using SolverConfigBC = typename Config::BCConfig;
107 using SolverBCFree = typename SolverConfigBC::Free;
108 using SolverBCPeriodic = typename SolverConfigBC::Periodic;
109 using SolverBCShearingPeriodic = typename SolverConfigBC::ShearingPeriodic;
110
111 // boundary condition selections
112 if (SolverBCFree *c
113 = std::get_if<SolverBCFree>(&solver_config.boundary_config.config)) {
114 storage.ghost_handler.set(
115 GhostHandle{
116 scheduler(),
117 BCFree{},
118 storage.patch_rank_owner,
119 storage.xyzh_ghost_layout});
120 } else if (
121 SolverBCPeriodic *c
122 = std::get_if<SolverBCPeriodic>(&solver_config.boundary_config.config)) {
123 storage.ghost_handler.set(
124 GhostHandle{
125 scheduler(),
126 BCPeriodic{},
127 storage.patch_rank_owner,
128 storage.xyzh_ghost_layout});
129 } else if (
130 SolverBCShearingPeriodic *c
131 = std::get_if<SolverBCShearingPeriodic>(&solver_config.boundary_config.config)) {
132 storage.ghost_handler.set(
133 GhostHandle{
134 scheduler(),
135 BCShearingPeriodic{
136 c->shear_base, c->shear_dir, c->shear_speed * time_val, c->shear_speed},
137 storage.patch_rank_owner,
138 storage.xyzh_ghost_layout});
139 }
140 }
141 inline void reset_ghost_handler() { storage.ghost_handler.reset(); }
142
144 void build_ghost_cache();
146 void clear_ghost_cache();
147
150
151 // trees
152 using RTree = typename Config::RTree;
157
161 void reset_presteps_rint();
162
167
169 void sph_prestep(Tscal time_val, Tscal dt);
170
172 void apply_position_boundary(Tscal time_val);
173
175 void do_predictor_leapfrog(Tscal dt);
176
178 void update_artificial_viscosity(Tscal dt);
179
181 void init_ghost_layout();
182
187
189 void compute_eos_fields();
190
192 void reset_eos_fields();
193
195 void prepare_corrector();
197 void update_derivs();
204 bool apply_corrector(Tscal dt, u64 Npart_all);
205
208
209 Solver(ShamrockCtx &context) : context(context) {}
210
212 void init_solver_graph();
213
215 void vtk_do_dump(std::string filename, bool add_patch_world_id);
216
217 void set_debug_dump(bool _do_debug_dump, std::string _debug_dump_filename) {
218 solver_config.set_debug_dump(_do_debug_dump, _debug_dump_filename);
219 }
220
221 inline void print_timestep_logs() {
222 if (shamcomm::world_rank() == 0) {
223 logger::info_ln("SPH", "iteration since start :", solve_logs.get_iteration_count());
224 logger::info_ln("SPH", "time since start :", shambase::details::get_wtime(), "(s)");
225 }
226 }
227
229 TimestepLog evolve_once();
230
232 Tscal evolve_once_time_expl(Tscal t_current, Tscal dt_input) {
233 solver_config.set_time(t_current);
234 solver_config.set_next_dt(dt_input);
235 evolve_once();
236 return solver_config.get_dt_sph();
237 }
238
239 inline bool evolve_until(Tscal target_time, i32 niter_max) {
240 auto step = [&]() {
241 Tscal dt = solver_config.get_dt_sph();
242 Tscal t = solver_config.get_time();
243
244 if (t > target_time) {
246 "the target time is higher than the current time");
247 }
248
249 if (t + dt > target_time) {
250 solver_config.set_next_dt(target_time - t);
251 }
252 evolve_once();
253 };
254
255 i32 iter_count = 0;
256
257 while (solver_config.get_time() < target_time) {
258 step();
259 iter_count++;
260
261 if ((iter_count >= niter_max) && (niter_max != -1)) {
262 logger::info_ln("SPH", "stopping evolve until because of niter =", iter_count);
263 return false;
264 }
265 }
266
267 print_timestep_logs();
268
269 return true;
270 }
271 };
272
273} // namespace shammodels::sph
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
The MPI scheduler.
The shamrock SPH model.
Definition Solver.hpp:61
void reset_presteps_rint()
Resets tree radius interval field.
Definition Solver.cpp:1219
void reset_merge_ghosts_fields()
Resets merged ghost field data.
Definition Solver.cpp:1484
void update_sync_load_values()
Updates load balancing values and synchronizes patch ownership.
Definition Solver.cpp:1571
bool apply_corrector(Tscal dt, u64 Npart_all)
Definition Solver.cpp:1566
void update_derivs()
Updates time derivatives and applies external forces.
Definition Solver.cpp:1556
void merge_position_ghost()
Merges ghost particle positions from neighboring patches.
Definition Solver.cpp:802
void reset_eos_fields()
Frees memory allocated for EOS fields.
Definition Solver.cpp:1510
void do_predictor_leapfrog(Tscal dt)
Performs predictor step for leapfrog integration.
Definition Solver.cpp:856
void prepare_corrector()
Saves old derivative fields for predictor-corrector integration.
Definition Solver.cpp:1516
void build_ghost_cache()
Builds ghost particle interface cache for inter-patch communication.
Definition Solver.cpp:780
void update_artificial_viscosity(Tscal dt)
Updates artificial viscosity coefficients for shock capturing.
Definition Solver.cpp:1493
TimestepLog evolve_once()
Performs one complete SPH timestep evolution.
Definition Solver.cpp:1578
void vtk_do_dump(std::string filename, bool add_patch_world_id)
Writes VTK dump file for visualization.
Definition Solver.cpp:557
void build_merged_pos_trees()
Builds spatial BVH trees for merged positions including ghosts.
Definition Solver.cpp:845
void clear_merged_pos_trees()
Clears merged position trees to free memory.
Definition Solver.cpp:850
void init_solver_graph()
Initializes the solver graph for computation pipeline.
Definition Solver.cpp:108
void sph_prestep(Tscal time_val, Tscal dt)
Performs pre-step operations for SPH timestep.
Definition Solver.cpp:920
void compute_presteps_rint()
Computes maximum smoothing length in tree nodes for neighbor search.
Definition Solver.cpp:1182
void compute_eos_fields()
Computes equation of state fields (pressure, sound speed)
Definition Solver.cpp:1504
void apply_position_boundary(Tscal time_val)
Applies position-based boundary conditions.
Definition Solver.cpp:734
void reset_neighbors_cache()
Resets neighbor cache.
Definition Solver.cpp:1249
Tscal evolve_once_time_expl(Tscal t_current, Tscal dt_input)
Evolves system by one explicit timestep with specified time and dt.
Definition Solver.hpp:232
void communicate_merge_ghosts_fields()
Communicates and merges ghost particle fields across processes.
Definition Solver.cpp:1254
void clear_ghost_cache()
Clears ghost particle cache to free memory.
Definition Solver.cpp:796
void init_ghost_layout()
Initializes data layout for ghost particle fields.
Definition Solver.cpp:1167
void start_neighbors_cache()
Builds neighbor particle cache for SPH calculations.
Definition Solver.cpp:1224
This header file contains utility functions related to exception handling in the code.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for the sph model
std::shared_ptr< shamrock::solvergraph::RankGetter > patch_rank_owner
Patch rank ownership.
std::shared_ptr< shamrock::patch::PatchDataLayerLayout > xyzh_ghost_layout
Ghost data layout and merged data.
Component< SerialPatchTree< Tvec > > serial_patch_tree
Serial patch tree for load balancing.
std::shared_ptr< solvergraph::GhostHandlerEdge< Tvec > > ghost_handler
Ghost handler for boundary particles.
The configuration for a sph solver.
BCConfig boundary_config
Boundary condition configuration.
void set_time(Tscal t)
Set the current time.
Tscal get_time()
Get the current time.
void set_debug_dump(bool _do_debug_dump, std::string _debug_dump_filename)
Set whether to dump debug information to file.
u32 u_morton
The type of the Morton code for the tree.
Tscal get_dt_sph()
Get the time step for the next iteration.
BCConfig< Tvec > BCConfig
Configuration of the boundary conditions.
void set_next_dt(Tscal dt)
Set the time step for the next iteration.
Class holding the solver logs. TODO: add a variable to keep only a definite number of steps in ...
Definition SolverLog.hpp:33