Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
Model.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
30#include "shambase/string.hpp"
34#include "shambackends/vec.hpp"
35#include "shamcomm/logs.hpp"
46#include <pybind11/functional.h>
47#include <stdexcept>
48#include <vector>
49
50namespace shammodels::gsph {
51
// GSPH model front-end, templated on the position vector type `Tvec` and on the
// SPH kernel template `SPHKernel`. It owns a Solver and refers to a ShamrockCtx.
62 template<class Tvec, template<class> class SPHKernel>
63 class Model {
64 public:
// Scalar component type of Tvec.
65 using Tscal = shambase::VecComponent<Tvec>;
// Kernel instantiated on the scalar type.
67 using Kernel = SPHKernel<Tscal>;
68
// Inside this class, `Solver` names this alias rather than the Solver class template.
69 using Solver = Solver<Tvec, SPHKernel>;
70 using SolverConfig = typename Solver::Config;
71
// Non-owning reference: the ShamrockCtx must outlive this Model.
72 ShamrockCtx &ctx;
73 Solver solver;
74
// Binds the context and constructs the solver from that same context.
75 Model(ShamrockCtx &ctx) : ctx(ctx), solver(ctx) {};
76
78 // Setup functions
80
// Initialise the model and all related data structures (the patch scheduler in
// particular). Declared here and defined out of line (Model.cpp).
82 void init();
83
86 inline void init_scheduler(u32 crit_split, u32 crit_merge) {
87 solver.solver_config.scheduler_conf.split_load_value = crit_split;
88 solver.solver_config.scheduler_conf.merge_load_value = crit_merge;
89 init();
90 }
91
92 template<std::enable_if_t<dim == 3, int> = 0>
93 inline Tvec get_box_dim_fcc_3d(Tscal dr, u32 xcnt, u32 ycnt, u32 zcnt) {
94 return generic::setup::generators::get_box_dim(dr, xcnt, ycnt, zcnt);
95 }
96
97 inline void set_cfl_cour(Tscal cfl_cour) {
98 solver.solver_config.cfl_config.cfl_cour = cfl_cour;
99 }
100
101 inline void set_cfl_force(Tscal cfl_force) {
102 solver.solver_config.cfl_config.cfl_force = cfl_force;
103 }
104
105 inline void set_particle_mass(Tscal gpart_mass) {
106 solver.solver_config.gpart_mass = gpart_mass;
107 }
108
109 inline Tscal get_particle_mass() { return solver.solver_config.gpart_mass; }
110
111 inline void resize_simulation_box(std::pair<Tvec, Tvec> box) {
112 ctx.set_coord_domain_bound({box.first, box.second});
113 }
114
115 void do_vtk_dump(std::string filename, bool add_patch_world_id) {
116 solver.vtk_do_dump(filename, add_patch_world_id);
117 }
118
// Declared here and defined out of line.
// NOTE(review): presumably the total particle count across all ranks — confirm in the definition.
119 u64 get_total_part_count();
120
// NOTE(review): judging by its name, converts a total mass into a per-particle mass — confirm.
121 f64 total_mass_to_part_mass(f64 totmass);
122
// Lattice-box helpers for spacing `dr` and the input `box`.
// NOTE(review): presumably return a box adjusted to fit whole FCC / HCP lattice cells — confirm.
123 std::pair<Tvec, Tvec> get_ideal_fcc_box(Tscal dr, std::pair<Tvec, Tvec> box);
124 std::pair<Tvec, Tvec> get_ideal_hcp_box(Tscal dr, std::pair<Tvec, Tvec> box);
125
126 Tscal get_hfact() { return Kernel::hfactd; }
127
128 Tscal rho_h(Tscal h) {
129 return shamrock::sph::rho_h(solver.solver_config.gpart_mass, h, Kernel::hfactd);
130 }
131
// Declared here and defined out of line.
// NOTE(review): presumably these add particles on an FCC / HCP lattice of spacing `dr`
// filling `_box` — confirm in the definitions.
132 void add_cube_fcc_3d(Tscal dr, std::pair<Tvec, Tvec> _box);
133 void add_cube_hcp_3d(Tscal dr, std::pair<Tvec, Tvec> _box);
134
136 // Field manipulation
138
157 template<class T>
159 std::string field_name, const std::function<T(Tvec)> pos_to_val) {
160
161 StackEntry stack_loc{};
162 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
163 sched.patch_data.for_each_patchdata(
164 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
166 = pdat.template get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
167
169 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(field_name));
170
171 if (f.get_nvar() != 1) {
173 }
174
175 {
176 auto &buf = f.get_buf();
177 auto acc = buf.copy_to_stdvec();
178
179 auto &buf_xyz = xyz.get_buf();
180 auto acc_xyz = buf_xyz.copy_to_stdvec();
181
182 for (u32 i = 0; i < f.get_obj_cnt(); i++) {
183 Tvec r = acc_xyz[i];
184 acc[i] = pos_to_val(r);
185 }
186
187 buf.copy_from_stdvec(acc);
188 buf_xyz.copy_from_stdvec(acc_xyz);
189 }
190 });
191 }
192
// Set component `ivar` of field `field_name` to `val` for every particle whose
// position ("xyz") satisfies BBAA::is_coord_in_range(r, box.first, box.second).
// This is done on each patch visited by sched.patch_data.for_each_patchdata.
// Values are stored interleaved: component `ivar` of particle `i` is at index i * nvar + ivar.
// ivar >= f.get_nvar() is rejected (presumably by throwing, with the message below).
214 template<class T>
215 inline void set_field_in_box(
216 std::string field_name, T val, std::pair<Tvec, Tvec> box, u32 ivar = 0) {
217 StackEntry stack_loc{};
218 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
219 sched.patch_data.for_each_patchdata(
220 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
222 = pdat.template get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
223
225 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(field_name));
226
227 u32 nvar = f.get_nvar();
228
229 // Validate ivar parameter to prevent out-of-bounds access
230 if (ivar >= nvar) {
232 "set_field_in_box: ivar ({}) >= f.get_nvar ({}) for field {}",
233 ivar,
234 nvar,
235 field_name));
236 }
237
// The inner scope bounds the lifetime of the host mirrors.
// NOTE(review): presumably the mirrors write back to the device buffers when destroyed — confirm in mirror_to.
238 {
239 auto acc = f.get_buf().template mirror_to<sham::host>();
240 auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();
241
242 for (u32 i = 0; i < f.get_obj_cnt(); i++) {
243 Tvec r = acc_xyz[i];
244
245 if (BBAA::is_coord_in_range(r, std::get<0>(box), std::get<1>(box))) {
246 acc[i * nvar + ivar] = val;
247 }
248 }
249 }
250 });
251 }
252
// Set field `field_name` to `val` for every particle strictly inside the sphere of
// given `center` and `radius` (|r - center|^2 < radius^2, so the boundary is excluded).
// This is done on each patch visited by sched.patch_data.for_each_patchdata.
// Only single-component fields are handled; the f.get_nvar() != 1 branch is the error path.
273 template<class T>
274 inline void set_field_in_sphere(std::string field_name, T val, Tvec center, Tscal radius) {
275 StackEntry stack_loc{};
276 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
277 sched.patch_data.for_each_patchdata(
278 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
280 = pdat.template get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
281
283 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(field_name));
284
285 if (f.get_nvar() != 1) {
287 }
288
// Compare squared distances so no square root is needed per particle.
289 Tscal r2 = radius * radius;
// The inner scope bounds the lifetime of the host mirrors.
// NOTE(review): presumably the mirrors write back to the device buffers when destroyed — confirm in mirror_to.
290 {
291 auto acc = f.get_buf().template mirror_to<sham::host>();
292 auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();
293
294 for (u32 i = 0; i < f.get_obj_cnt(); i++) {
295 Tvec dr = acc_xyz[i] - center;
296
297 if (sycl::dot(dr, dr) < r2) {
298 acc[i] = val;
299 }
300 }
301 }
302 });
303 }
304
// Sum of field `name` over all particles. Each patch adds compute_sum() of its
// field to the local accumulator, and the local totals are then combined with
// shamalgs::collective::allreduce_sum (presumably a sum across MPI ranks).
// NOTE(review): the field reference inside the lambda is named `xyz`, but it holds
// field `name`, not positions — consider renaming it.
305 template<class T>
306 inline T get_sum(std::string name) {
307 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
309
310 StackEntry stack_loc{};
311 sched.patch_data.for_each_patchdata(
312 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
314 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(name));
315
316 sum += xyz.compute_sum();
317 });
318
319 return shamalgs::collective::allreduce_sum(sum);
320 }
321
323 // Solver configuration
325
326 inline SolverConfig gen_default_config() {
327 SolverConfig cfg;
328 cfg.set_riemann_iterative(); // Default to iterative Riemann solver
329 cfg.set_reconstruct_piecewise_constant(); // Default to 1st order (piecewise constant)
330 cfg.set_eos_adiabatic(Tscal{1.4});
331 cfg.set_boundary_periodic();
332 return cfg;
333 }
334
// Replace the solver configuration with `cfg`.
// If the scheduler is already initialized, the change is refused (presumably by
// throwing, with the message below). Otherwise `cfg` is validated with
// check_config() before being stored in the solver.
335 inline void set_solver_config(SolverConfig cfg) {
336 if (ctx.is_scheduler_initialized()) {
338 "Cannot change solver config after scheduler is initialized");
339 }
340 cfg.check_config();
341 solver.solver_config = cfg;
342 }
343
344 inline f64 solver_logs_last_rate() { return solver.solve_logs.get_last_rate(); }
345 inline u64 solver_logs_last_obj_count() { return solver.solve_logs.get_last_obj_count(); }
346 inline shamsys::SystemMetrics solver_logs_last_system_metrics() {
347 return solver.solve_logs.get_last_system_metrics();
348 }
349
351 // I/O (uses shared ShamrockDump mechanism like SPH)
353
// Restore a simulation from the Shamrock dump file `fname`.
// Steps: load the patches and the user-metadata string through
// shamrock::load_shamrock_dump, read the solver config from that JSON metadata,
// rebuild the solver's ghost layout and solver graph, then refresh the scheduler's
// owned-patch set and per-patch load values.
354 inline void load_from_dump(std::string fname) {
// Only MPI rank 0 logs, so the message is printed once.
355 if (shamcomm::world_rank() == 0) {
356 logger::info_ln("GSPH", "Loading state from dump", fname);
357 }
358
359 std::string metadata_user{};
360 shamrock::load_shamrock_dump(fname, metadata_user, ctx);
361
// json::at throws if the metadata has no "solver_config" entry.
362 nlohmann::json j = nlohmann::json::parse(metadata_user);
363 j.at("solver_config").get_to(solver.solver_config);
364
365 solver.init_ghost_layout();
366 solver.init_solver_graph();
367
368 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
369 sched.owned_patch_id = sched.patch_list.build_local();
// The load value of a patch is its object count in owned_data.
372 sched.update_local_load_value([&](shamrock::patch::Patch p) {
373 return sched.patch_data.owned_data.get(p.id_patch).get_obj_cnt();
374 });
375 }
376
// Write the current simulation state to the Shamrock dump file `fname`.
// solver.update_sync_load_values() is called first. The solver configuration is then
// stored as JSON user metadata, serialised with metadata.dump(4) (4-space indent).
// NOTE(review): the writer call is presumably
// shamrock::write_shamrock_dump(fname, metadata_user, sched) — its first line is not visible here.
377 inline void dump(std::string fname) {
378 if (shamcomm::world_rank() == 0) {
379 logger::info_ln("GSPH", "Dumping state to", fname);
380 }
381
382 solver.update_sync_load_values();
383
384 nlohmann::json metadata;
385 metadata["solver_config"] = solver.solver_config;
386
388 fname, metadata.dump(4), shambase::get_check_ref(ctx.sched));
389 }
390
392 // Simulation control
394
395 TimestepLog timestep() { return solver.evolve_once(); }
396
397 inline void evolve_once() {
398 solver.evolve_once();
399 solver.print_timestep_logs();
400 }
401
402 inline bool evolve_until(Tscal target_time, i32 niter_max = -1) {
403 return solver.evolve_until(target_time, niter_max);
404 }
405 };
406
407} // namespace shammodels::gsph
Header file describing a Node Instance.
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
The MPI scheduler.
SchedulerPatchData patch_data
handle the data of the patches of the scheduler
SchedulerPatchList patch_list
handle the list of the patches of the scheduler
std::unordered_set< u64 > owned_patch_id
(owned_patch_id = patch_list.build_local())
std::unordered_set< u64 > build_local()
Select the patches owned by the current node (used to rebuild the local patch set)
void build_local_idx_map()
recompute id_patch_to_local_idx
void build_global_idx_map()
recompute id_patch_to_global_idx
bool is_scheduler_initialized()
returns true if the scheduler is initialized
The GSPH Model class.
Definition Model.hpp:63
void init_scheduler(u32 crit_split, u32 crit_merge)
Definition Model.hpp:86
void apply_field_from_position(std::string field_name, const std::function< T(Tvec)> pos_to_val)
Apply a position-dependent function to initialize a field.
Definition Model.hpp:158
void set_field_in_sphere(std::string field_name, T val, Tvec center, Tscal radius)
Set field value for particles within a spherical region.
Definition Model.hpp:274
void init()
Initialise the model and all the related data structures (patch scheduler in particular)
Definition Model.cpp:40
void set_field_in_box(std::string field_name, T val, std::pair< Tvec, Tvec > box, u32 ivar=0)
Set field value for particles within a box region.
Definition Model.hpp:215
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
shambase::DistributedData< PatchData > owned_data
map container for patchdata owned by the current node (layout : id_patch,data)
Class holding the value of numerous constants generated from the following source.
This header file contains utility functions related to exception handling in the code.
MPI string gather / allgather helpers (declarations; implementations in shamalgs/src/collective/gathe...
GSPH Solver class.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
void load_shamrock_dump(std::string fname, std::string &metadata_user, ShamrockCtx &ctx)
Load a Shamrock dump file, restore the state of the patches, and retrieve the user metadata.
void write_shamrock_dump(std::string fname, std::string metadata_user, PatchScheduler &sched)
Write a Shamrock dump file containing the current state of the patches and user supplied metadata.
The configuration for a GSPH solver.
Patch object that contain generic patch information.
Definition Patch.hpp:33