Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
Model.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
21#include "shambase/memory.hpp"
22#include "shambackends/vec.hpp"
30
32
33 template<class Tvec, class TgridVec>
34 class Model {
35 public:
36 using Tscal = shambase::VecComponent<Tvec>;
38 ShamrockCtx &ctx;
39
40 using Solver = Solver<Tvec, TgridVec>;
41 Solver solver;
42
43 Model(ShamrockCtx &ctx) : ctx(ctx), solver(ctx) {};
44
48
50 void init();
51
54 inline void init_scheduler(u32 crit_split, u32 crit_merge) {
55 solver.solver_config.scheduler_conf.split_load_value = crit_split;
56 solver.solver_config.scheduler_conf.merge_load_value = crit_merge;
57 init();
58 }
59
60 void make_base_grid(TgridVec bmin, TgridVec cell_size, u32_3 cell_count);
61
62 void dump_vtk(std::string filename);
63
// Set component `offset` of the per-cell field `field_name` (type T) from a
// position callback, for every patch data visited by
// sched.patch_data.for_each_patchdata.
// For each block i, the block bounds are read from fields 0 and 1 (grid
// coordinates), converted to Tscal and scaled by grid_coord_to_pos_fact.
// For each cell `lid` of the block, pos_to_val(cell_min, cell_max) is evaluated
// and stored at acc[(i * block_size + lid) * f_nvar + offset].
// The field buffer is copied into a std::vector, updated on the host, and
// copied back.
// NOTE(review): `offset` is not checked against f_nvar here; presumably it
// must satisfy 0 <= offset < f_nvar. Confirm with callers.
// NOTE(review): this listing is missing source line 76 (rest of the lambda
// parameter list, presumably the patch data `pdat`) and line 80 (declaration
// of the field reference `f`).
64 template<class T>
65 inline void set_field_value_lambda(
66 std::string field_name,
67 const std::function<T(Tvec, Tvec)> pos_to_val,
68 const i32 offset) {
69
70 StackEntry stack_loc{};
71
72 using Block = typename Solver::Config::AMRBlock;
73
74 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
75 sched.patch_data.for_each_patchdata([&](u64 patch_id,
// Fields 0 and 1 hold the min and max corner of each block, in grid coordinates.
77 sham::DeviceBuffer<TgridVec> &buf_cell_min = pdat.get_field_buf_ref<TgridVec>(0);
78 sham::DeviceBuffer<TgridVec> &buf_cell_max = pdat.get_field_buf_ref<TgridVec>(1);
79
// Target field: looked up by name and type T in the patch data layout.
81 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(field_name));
82
83 auto acc = f.get_buf().copy_to_stdvec();
84
// Variables per cell: the field's nvar counts all block_size cells of a block.
85 auto f_nvar = f.get_nvar() / Block::block_size;
86
87 auto cell_min = buf_cell_min.copy_to_stdvec();
88 auto cell_max = buf_cell_max.copy_to_stdvec();
89
90 Tscal scale_factor = solver.solver_config.grid_coord_to_pos_fact;
91 for (u32 i = 0; i < pdat.get_obj_cnt(); i++) {
92 Tvec block_min = cell_min[i].template convert<Tscal>() * scale_factor;
93 Tvec block_max = cell_max[i].template convert<Tscal>() * scale_factor;
// Per-cell extent, presumably with side_size cells along each axis of a block.
94 Tvec delta_cell = (block_max - block_min) / Block::side_size;
95
// `delta` is used as the offset of cell `lid` from the block's min corner.
96 Block::for_each_cell_in_block(delta_cell, [&](u32 lid, Tvec delta) {
97 Tvec bmin = block_min + delta;
98 acc[(i * Block::block_size + lid) * f_nvar + offset]
99 = pos_to_val(bmin, bmin + delta_cell);
100 });
101 }
102
// Write the updated values back into the field buffer.
103 f.get_buf().copy_from_stdvec(acc);
104 });
105 }
106
107 inline std::pair<Tvec, Tvec> get_cell_coords(
108 std::pair<TgridVec, TgridVec> block_coords, u32 lid) {
109 using Block = typename Solver::Config::AMRBlock;
110 auto tmp = Block::utils_get_cell_coords(block_coords, lid);
111 tmp.first *= solver.solver_config.grid_coord_to_pos_fact;
112 tmp.second *= solver.solver_config.grid_coord_to_pos_fact;
113 return tmp;
114 }
115
116 inline f64 evolve_once_time_expl(f64 t_curr, f64 dt_input) {
117 return solver.evolve_once_time_expl(t_curr, dt_input);
118 }
119
120 inline void timestep() { solver.evolve_once(); }
121
122 inline void evolve_once() {
123 solver.evolve_once();
124 solver.print_timestep_logs();
125 }
126
127 inline bool evolve_until(Tscal target_time, i32 niter_max) {
128 return solver.evolve_until(target_time, niter_max);
129 }
130
134
// Write the current state of the model to the dump file `fname`.
// The solver configuration is serialised to JSON (4-space indent) under the key
// "solver_config" and passed as the user metadata string, together with the
// context's scheduler, to the dump writer.
// NOTE(review): the first line of that call (source line 143) is missing from
// this listing; it is presumably shamrock::write_shamrock_dump(. Confirm.
135 inline void dump(std::string fname) {
// Only rank 0 logs, to avoid one message per MPI process.
136 if (shamcomm::world_rank() == 0) {
137 logger::info_ln("Godunov", "Dumping state to", fname);
138 }
139
140 nlohmann::json metadata;
141 metadata["solver_config"] = solver.solver_config;
142
144 fname, metadata.dump(4), shambase::get_check_ref(ctx.sched));
145 }
146
// Load the state of the model from the dump file `fname`.
// Restores the context's patches via shamrock::load_shamrock_dump, and restores
// solver_config from the "solver_config" entry of the user metadata JSON. It
// then rebuilds the scheduler's owned-patch set and recomputes each local
// patch's load value as its object count.
// nlohmann::json::parse throws on malformed metadata, and json::at throws if
// "solver_config" is absent.
152 inline void load_from_dump(std::string fname) {
// Only rank 0 logs, to avoid one message per MPI process.
153 if (shamcomm::world_rank() == 0) {
154 logger::info_ln("Godunov", "Loading state from dump", fname);
155 }
156
157 // Load the context state and recover user metadata
158 std::string metadata_user{};
159 shamrock::load_shamrock_dump(fname, metadata_user, ctx);
160
161 nlohmann::json j = nlohmann::json::parse(metadata_user);
162 j.at("solver_config").get_to(solver.solver_config);
163
// NOTE(review): disabled ghost-zone rebuild left as commented-out code; either
// remove it or restore it.
164 // modules::GhostZones gz(ctx, solver.solver_config, storage);
165 // gz.build_ghost_cache();
166
167 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
168 shamlog_debug_ln("Sys", "build local scheduler tables");
169 sched.owned_patch_id = sched.patch_list.build_local();
// NOTE(review): source lines 170-171 are missing from this listing. The page's
// symbol index references patch_list.build_local_idx_map() and
// build_global_idx_map(), which are presumably called here. Confirm.
// The load value of each owned patch is its number of stored objects.
172 sched.update_local_load_value([&](shamrock::patch::Patch p) {
173 return sched.patch_data.owned_data.get(p.id_patch).get_obj_cnt();
174 });
175 }
176 };
177
178} // namespace shammodels::basegodunov
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
The MPI scheduler.
SchedulerPatchData patch_data
handle the data of the patches of the scheduler
SchedulerPatchList patch_list
handle the list of the patches of the scheduler
std::unordered_set< u64 > owned_patch_id
(owned_patch_id = patch_list.build_local())
std::unordered_set< u64 > build_local()
Select the patches owned by the current node (used to rebuild the local scheduler tables)
void build_local_idx_map()
recompute id_patch_to_local_idx
void build_global_idx_map()
recompute id_patch_to_global_idx
A buffer allocated in USM (Unified Shared Memory)
std::vector< T > copy_to_stdvec() const
Copy the content of the buffer to a std::vector.
void init_scheduler(u32 crit_split, u32 crit_merge)
Definition Model.hpp:54
void init()
Initialise the model and all the related data structures (patch scheduler in particular)
Definition Model.cpp:27
void load_from_dump(std::string fname)
Load the state of the Godunov model from a dump file.
Definition Model.hpp:152
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
shambase::DistributedData< PatchData > owned_data
map container for patchdata owned by the current node (layout : id_patch,data)
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for the basegodunov model
void load_shamrock_dump(std::string fname, std::string &metadata_user, ShamrockCtx &ctx)
Load a Shamrock dump file, restore the state of the patches, and retrieve the user metadata.
void write_shamrock_dump(std::string fname, std::string metadata_user, PatchScheduler &sched)
Write a Shamrock dump file containing the current state of the patches and user supplied metadata.
Patch object that contain generic patch information.
Definition Patch.hpp:33