65 using Tscal = shambase::VecComponent<Tvec>;
67 using Kernel = SPHKernel<Tscal>;
69 using Solver = Solver<Tvec, SPHKernel>;
87 solver.solver_config.scheduler_conf.split_load_value = crit_split;
88 solver.solver_config.scheduler_conf.merge_load_value = crit_merge;
92 template<std::enable_if_t<dim == 3,
int> = 0>
93 inline Tvec get_box_dim_fcc_3d(Tscal dr,
u32 xcnt,
u32 ycnt,
u32 zcnt) {
94 return generic::setup::generators::get_box_dim(dr, xcnt, ycnt, zcnt);
97 inline void set_cfl_cour(Tscal cfl_cour) {
98 solver.solver_config.cfl_config.cfl_cour = cfl_cour;
101 inline void set_cfl_force(Tscal cfl_force) {
102 solver.solver_config.cfl_config.cfl_force = cfl_force;
105 inline void set_particle_mass(Tscal gpart_mass) {
106 solver.solver_config.gpart_mass = gpart_mass;
109 inline Tscal get_particle_mass() {
return solver.solver_config.gpart_mass; }
111 inline void resize_simulation_box(std::pair<Tvec, Tvec> box) {
112 ctx.set_coord_domain_bound({box.first, box.second});
115 void do_vtk_dump(std::string filename,
bool add_patch_world_id) {
116 solver.vtk_do_dump(filename, add_patch_world_id);
119 u64 get_total_part_count();
121 f64 total_mass_to_part_mass(
f64 totmass);
123 std::pair<Tvec, Tvec> get_ideal_fcc_box(Tscal dr, std::pair<Tvec, Tvec> box);
124 std::pair<Tvec, Tvec> get_ideal_hcp_box(Tscal dr, std::pair<Tvec, Tvec> box);
126 Tscal get_hfact() {
return Kernel::hfactd; }
128 Tscal rho_h(Tscal h) {
129 return shamrock::sph::rho_h(solver.solver_config.gpart_mass, h, Kernel::hfactd);
132 void add_cube_fcc_3d(Tscal dr, std::pair<Tvec, Tvec> _box);
133 void add_cube_hcp_3d(Tscal dr, std::pair<Tvec, Tvec> _box);
159 std::string field_name,
const std::function<T(Tvec)> pos_to_val) {
166 = pdat.template get_field<Tvec>(sched.pdl_old().
get_field_idx<Tvec>(
"xyz"));
169 = pdat.template get_field<T>(sched.pdl_old().
get_field_idx<T>(field_name));
171 if (f.get_nvar() != 1) {
176 auto &buf = f.get_buf();
177 auto acc = buf.copy_to_stdvec();
179 auto &buf_xyz = xyz.get_buf();
180 auto acc_xyz = buf_xyz.copy_to_stdvec();
182 for (
u32 i = 0; i < f.get_obj_cnt(); i++) {
184 acc[i] = pos_to_val(r);
187 buf.copy_from_stdvec(acc);
188 buf_xyz.copy_from_stdvec(acc_xyz);
216 std::string field_name, T val, std::pair<Tvec, Tvec> box,
u32 ivar = 0) {
222 = pdat.template get_field<Tvec>(sched.pdl_old().
get_field_idx<Tvec>(
"xyz"));
225 = pdat.template get_field<T>(sched.pdl_old().
get_field_idx<T>(field_name));
227 u32 nvar = f.get_nvar();
232 "set_field_in_box: ivar ({}) >= f.get_nvar ({}) for field {}",
239 auto acc = f.get_buf().template mirror_to<sham::host>();
240 auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();
242 for (
u32 i = 0; i < f.get_obj_cnt(); i++) {
245 if (BBAA::is_coord_in_range(r, std::get<0>(box), std::get<1>(box))) {
246 acc[i * nvar + ivar] = val;
280 = pdat.template get_field<Tvec>(sched.pdl_old().
get_field_idx<Tvec>(
"xyz"));
283 = pdat.template get_field<T>(sched.pdl_old().
get_field_idx<T>(field_name));
285 if (f.get_nvar() != 1) {
289 Tscal r2 = radius * radius;
291 auto acc = f.get_buf().template mirror_to<sham::host>();
292 auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();
294 for (
u32 i = 0; i < f.get_obj_cnt(); i++) {
295 Tvec dr = acc_xyz[i] - center;
297 if (sycl::dot(dr, dr) < r2) {
306 inline T get_sum(std::string name) {
314 = pdat.template get_field<T>(sched.pdl_old().
get_field_idx<T>(name));
316 sum += xyz.compute_sum();
319 return shamalgs::collective::allreduce_sum(sum);
326 inline SolverConfig gen_default_config() {
328 cfg.set_riemann_iterative();
329 cfg.set_reconstruct_piecewise_constant();
330 cfg.set_eos_adiabatic(Tscal{1.4});
331 cfg.set_boundary_periodic();
335 inline void set_solver_config(SolverConfig cfg) {
338 "Cannot change solver config after scheduler is initialized");
341 solver.solver_config = cfg;
344 inline f64 solver_logs_last_rate() {
return solver.solve_logs.get_last_rate(); }
345 inline u64 solver_logs_last_obj_count() {
return solver.solve_logs.get_last_obj_count(); }
347 return solver.solve_logs.get_last_system_metrics();
354 inline void load_from_dump(std::string fname) {
356 logger::info_ln(
"GSPH",
"Loading state from dump", fname);
359 std::string metadata_user{};
362 nlohmann::json j = nlohmann::json::parse(metadata_user);
363 j.at(
"solver_config").get_to(solver.solver_config);
365 solver.init_ghost_layout();
366 solver.init_solver_graph();
377 inline void dump(std::string fname) {
379 logger::info_ln(
"GSPH",
"Dumping state to", fname);
382 solver.update_sync_load_values();
384 nlohmann::json metadata;
385 metadata[
"solver_config"] = solver.solver_config;
395 TimestepLog timestep() {
return solver.evolve_once(); }
397 inline void evolve_once() {
398 solver.evolve_once();
399 solver.print_timestep_logs();
402 inline bool evolve_until(Tscal target_time,
i32 niter_max = -1) {
403 return solver.evolve_until(target_time, niter_max);