// Scalar component type of the vector type Tvec.
72 using Tscal = shambase::VecComponent<Tvec>;
// Value of shambase::VectorProperties<Tvec>::dimension, cached as a class constant.
73 static constexpr u32 dim = shambase::VectorProperties<Tvec>::dimension;
// SPHKernel instantiated with the scalar type.
74 using Kernel = SPHKernel<Tscal>;
// Shorthand for Kernel::Rkern (presumably the kernel support radius in units of the
// smoothing length -- confirm against the SPHKernel definition).
80 static constexpr Tscal Rkern = Kernel::Rkern;
// Optional callback taking no arguments and returning nothing; empty by default.
// By name it is meant to run when a step begins -- the call site is not visible here.
91 std::optional<std::function<void(
void)>> step_begin_callback;
// Optional callback taking no arguments and returning nothing; empty by default.
// By name it is meant to run when a step ends -- the call site is not visible here.
92 std::optional<std::function<void(
void)>> step_end_callback;
// List of SolverStepCallback entries, initially empty. When and how the entries are
// invoked is not visible here.
94 std::vector<SolverStepCallback> timestep_callbacks{};
/// Forwards the result of context.get_pdl_write() to solver_config.set_layout().
/// NOTE(review): "pdl" presumably stands for "patch data layout" -- confirm against ShamrockCtx.
96 inline void init_required_fields() { solver_config.set_layout(context.get_pdl_write()); }
/// Declaration only; the definition is not part of this listing.
/// NOTE(review): by name, and given that reset_serial_patch_tree() resets
/// storage.serial_patch_tree, this presumably populates that same member -- confirm
/// against the definition.
99 void gen_serial_patch_tree();
/// Calls reset() on storage.serial_patch_tree.
100 inline void reset_serial_patch_tree() { storage.serial_patch_tree.reset(); }
// Ghost handler type used by this solver: sph::BasicSPHGhostHandler specialised on Tvec.
103 using GhostHandle = sph::BasicSPHGhostHandler<Tvec>;
// The CacheMap type exposed by the ghost handler.
104 using GhostHandleCache =
typename GhostHandle::CacheMap;
/// Installs a ghost handler into storage.ghost_handler, choosing its boundary
/// configuration from the alternative held by solver_config.boundary_config.config.
///
/// NOTE(review): several original lines of this method are absent from this listing
/// (the first branch condition, the leading arguments of every set(...) call and the
/// closing braces). The comments below describe the visible statements only.
///
/// @param time_val multiplied by shear_speed in the shearing-periodic branch; it is not
///        referenced by any other visible line.
106 inline void gen_ghost_handler(Tscal time_val) {
// Boundary-condition types on the ghost handler side.
108 using CfgClass = sph::BasicSPHGhostHandlerConfig<Tvec>;
109 using BCConfig =
typename CfgClass::Variant;
111 using BCFree =
typename CfgClass::Free;
112 using BCPeriodic =
typename CfgClass::Periodic;
113 using BCShearingPeriodic =
typename CfgClass::ShearingPeriodic;
// Boundary-condition alternatives on the solver configuration side
// (SolverConfigBC is declared on a line that is not part of this listing).
116 using SolverBCFree =
typename SolverConfigBC::Free;
117 using SolverBCPeriodic =
typename SolverConfigBC::Periodic;
118 using SolverBCShearingPeriodic =
typename SolverConfigBC::ShearingPeriodic;
// First set(...) call. Its guarding condition is not shown; presumably this is the
// SolverBCFree case -- confirm against the full source.
123 storage.ghost_handler.set(
127 storage.patch_rank_owner,
128 storage.xyzh_ghost_layout});
// Periodic case: std::get_if returns a non-null pointer only when the boundary
// configuration currently holds a SolverBCPeriodic.
131 = std::get_if<SolverBCPeriodic>(&solver_config.boundary_config.config)) {
132 storage.ghost_handler.set(
136 storage.patch_rank_owner,
137 storage.xyzh_ghost_layout});
// Shearing-periodic case: entered when the configuration holds a SolverBCShearingPeriodic.
139 SolverBCShearingPeriodic *c
140 = std::get_if<SolverBCShearingPeriodic>(&solver_config.boundary_config.config)) {
141 storage.ghost_handler.set(
// The third value is shear_speed * time_val (presumably the shear offset accumulated
// up to time_val -- confirm against the ShearingPeriodic config type); the raw
// shear_speed is passed as well.
145 c->shear_base, c->shear_dir, c->shear_speed * time_val, c->shear_speed},
146 storage.patch_rank_owner,
147 storage.xyzh_ghost_layout});
/// Calls reset() on storage.ghost_handler.
150 inline void reset_ghost_handler() { storage.ghost_handler.reset(); }
// Tree type taken from Config (Config's declaration is not visible in this listing).
161 using RTree =
typename Config::RTree;
/// Initializes the member `context` from the given ShamrockCtx; the body does nothing else.
/// NOTE(review): the member's declaration is not visible here. If it is a reference,
/// the ShamrockCtx must outlive the solver -- confirm.
215 Solver(ShamrockCtx &context) : context(context) {}
/// Declaration only; the definition is not part of this listing.
/// NOTE(review): by name this writes a VTK dump to `filename`, and `add_patch_world_id`
/// presumably toggles an extra per-patch identifier in the output -- confirm against
/// the definition.
221 void vtk_do_dump(std::string filename,
bool add_patch_world_id);
/// Forwards both arguments unchanged to solver_config.set_debug_dump().
/// (The closing brace of this method is on a line that is not part of this listing.)
///
/// @param _do_debug_dump       flag passed through to the solver configuration
/// @param _debug_dump_filename file name passed through to the solver configuration
223 void set_debug_dump(
bool _do_debug_dump, std::string _debug_dump_filename) {
224 solver_config.set_debug_dump(_do_debug_dump, _debug_dump_filename);
/// Emits informational log lines under the "SPH" tag.
/// (The end of this method is not part of this listing; only the two visible log
/// calls are described.)
227 inline void print_timestep_logs() {
// Iteration counter as reported by solve_logs.get_iteration_count().
229 logger::info_ln(
"SPH",
"iteration since start :", solve_logs.get_iteration_count());
// Value of shambase::details::get_wtime(), printed with a "(s)" unit suffix.
230 logger::info_ln(
"SPH",
"time since start :", shambase::details::get_wtime(),
"(s)");
// NOTE(review): fragment of a method whose signature and remaining statements are not
// part of this listing. The visible lines store t_current and dt_input into
// solver_config, then return solver_config.get_dt_sph().
239 solver_config.set_time(t_current);
240 solver_config.set_next_dt(dt_input);
242 return solver_config.get_dt_sph();
// Parameters and body of the time-evolution driver (named "evolve_until" in its own log
// message). The function header and a number of body lines are not part of this
// listing; the comments describe the visible statements only.
//
// target_time  : the loop runs while solver_config.get_time() < target_time.
// niter_max    : iteration cap; a negative value disables the cap.
// max_walltime : wall-time cap compared against synced_wtime(); a negative value
//                (the default, -1) disables the cap.
//
// Three result objects are built with the fields reach_target_time, reach_niter_max,
// reach_max_walltime and iter_count, one for each way the function can finish.
246 Tscal target_time,
i32 niter_max,
f64 max_walltime = -1) {
// A limit is active only when its value is non-negative.
248 const bool niter_limit_active = (niter_max >= 0);
249 const bool walltime_limit_active = (max_walltime >= 0);
255 "evolve_until (target_time = {:.2f}s, niter_max = {}, max_walltime = "
// Wall-time helper; the visible part only shows that it branches on
// walltime_limit_active (NOTE(review): the name suggests a value synchronised across
// ranks -- confirm against the missing lines).
262 auto synced_wtime = [&]() ->
f64 {
263 if (walltime_limit_active) {
270 Tscal dt = solver_config.get_dt_sph();
271 Tscal t = solver_config.get_time();
// Guard against a target that already lies in the past. The reported message now
// matches this condition (it previously claimed the target time was *higher*).
273 if (t > target_time) {
275 "the target time is lower than the current time");
// Shorten the next step so that it does not go past target_time.
278 if (t + dt > target_time) {
279 solver_config.set_next_dt(target_time - t);
284 f64 start_wall_time = (walltime_limit_active) ? synced_wtime() : 0;
// Iteration index of the next wall-time check: 1 when the limit is active, otherwise
// the largest i32 so that the check below is never reached.
286 i32 next_walltime_check_iter
287 = walltime_limit_active ? 1 : std::numeric_limits<i32>::max();
291 while (solver_config.get_time() < target_time) {
// Exit 1: iteration cap reached.
296 if (niter_limit_active && iter_count >= niter_max) {
299 "SPH",
"stopping evolve until because of niter =", iter_count);
302 .reach_target_time =
false,
303 .reach_niter_max =
true,
304 .reach_max_walltime =
false,
305 .iter_count = iter_count,
// Wall-time check, performed only once iter_count has reached the scheduled index.
310 if (walltime_limit_active && iter_count >= next_walltime_check_iter) {
311 f64 global_walltime = synced_wtime();
// Exit 2: wall-time cap reached.
314 if (global_walltime >= max_walltime) {
319 "stopping evolve until because of "
320 "max_walltime = {:.2f}s > {:.2f}s",
325 .reach_target_time =
false,
326 .reach_niter_max =
false,
327 .reach_max_walltime =
true,
328 .iter_count = iter_count,
// Average wall time per iteration since start_wall_time.
333 = (global_walltime - start_wall_time) /
static_cast<f64>(iter_count);
// Number of iterations that fit into factor * delta_walltime at the current average
// cost, clamped to the largest i32. (The return for sec_per_iter <= 0 is on a line
// that is not part of this listing.)
335 auto get_remaining_iters = [&](
f64 delta_walltime,
f64 factor) ->
i32 {
336 if (sec_per_iter > 0) {
337 f64 tmp = factor * delta_walltime / sec_per_iter;
338 if (tmp > std::numeric_limits<i32>::max()) {
339 return std::numeric_limits<i32>::max();
341 return static_cast<i32>(tmp);
// Schedule the next check after a quarter of the estimated remaining iterations.
346 i32 iters_to_limit = get_remaining_iters(max_walltime - global_walltime, 0.25);
347 i32 iters_to_next_check = iters_to_limit;
// Saturating add: get_remaining_iters() may return numeric_limits<i32>::max(), and
// adding that to iter_count (>= 1 here) would overflow i32. The stride is therefore
// capped at the headroom left above iter_count; at least one iteration is still
// skipped in every non-degenerate case.
349 next_walltime_check_iter = iter_count + std::min<i32>(std::max(1, iters_to_next_check), std::numeric_limits<i32>::max() - iter_count);
355 "next walltime check in {:.2f}s (niter = {}) global walltime = "
356 "{:.2f}s (max_walltime = {:.2f}s)",
357 iters_to_next_check * sec_per_iter,
365 print_timestep_logs();
// Exit 3: the loop ended because the target time was reached.
368 .reach_target_time =
true,
369 .reach_niter_max =
false,
370 .reach_max_walltime =
false,
371 .iter_count = iter_count,