Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
Model.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
19
22#include "shambase/string.hpp"
26#include "shambackends/vec.hpp"
27#include "shamcomm/logs.hpp"
42#include <pybind11/functional.h>
43#include <stdexcept>
44#include <vector>
45
46namespace shammodels::sph {
47
54 template<class Tvec, template<class> class SPHKernel>
55 class Model {
56 public:
57 using Tscal = shambase::VecComponent<Tvec>;
58 static constexpr u32 dim = shambase::VectorProperties<Tvec>::dimension;
59 using Kernel = SPHKernel<Tscal>;
60
61 using Solver = Solver<Tvec, SPHKernel>;
62 using SolverConfig = typename Solver::Config;
63 // using SolverConfig = typename Solver::Config;
64
65 ShamrockCtx &ctx;
66
67 Solver solver;
68
69 // SolverConfig sconfig;
70
71 Model(ShamrockCtx &ctx) : ctx(ctx), solver(ctx) {};
72
76
78 void init();
79
82 inline void init_scheduler(u32 crit_split, u32 crit_merge) {
83 solver.solver_config.scheduler_conf.split_load_value = crit_split;
84 solver.solver_config.scheduler_conf.merge_load_value = crit_merge;
85 init();
86 }
87
88 template<std::enable_if_t<dim == 3, int> = 0>
89 inline Tvec get_box_dim_fcc_3d(Tscal dr, u32 xcnt, u32 ycnt, u32 zcnt) {
90 return generic::setup::generators::get_box_dim(dr, xcnt, ycnt, zcnt);
91 }
92
93 inline void set_cfl_cour(Tscal cfl_cour) {
94 solver.solver_config.cfl_config.cfl_cour = cfl_cour;
95 }
96 inline void set_cfl_force(Tscal cfl_force) {
97 solver.solver_config.cfl_config.cfl_force = cfl_force;
98 }
99 inline void set_eta_sink(Tscal eta_sink) {
100 solver.solver_config.cfl_config.eta_sink = eta_sink;
101 }
102 inline void set_particle_mass(Tscal gpart_mass) {
103 solver.solver_config.gpart_mass = gpart_mass;
104 }
105
106 inline Tscal get_particle_mass() { return solver.solver_config.gpart_mass; }
107
108 inline void resize_simulation_box(std::pair<Tvec, Tvec> box) {
109 ctx.set_coord_domain_bound({box.first, box.second});
110 }
111
112 SolverConfig gen_config_from_phantom_dump(PhantomDump &phdump, bool bypass_error);
113 void init_from_phantom_dump(PhantomDump &phdump, Tscal hpart_fact_load = 1.0);
114 PhantomDump make_phantom_dump();
115
116 void do_vtk_dump(std::string filename, bool add_patch_world_id) {
117 solver.vtk_do_dump(filename, add_patch_world_id);
118 }
119
120 void set_debug_dump(bool _do_debug_dump, std::string _debug_dump_filename) {
121 solver.set_debug_dump(_do_debug_dump, _debug_dump_filename);
122 }
123
124 u64 get_total_part_count();
125
126 f64 total_mass_to_part_mass(f64 totmass);
127
128 Tscal get_hfact() { return Kernel::hfactd; }
129
130 Tscal rho_h(Tscal h) {
131 return shamrock::sph::rho_h(solver.solver_config.gpart_mass, h, Kernel::hfactd);
132 }
133
134 void add_cube_fcc_3d(Tscal dr, std::pair<Tvec, Tvec> _box);
135 void add_cube_hcp_3d(Tscal dr, std::pair<Tvec, Tvec> _box);
136 void add_cube_hcp_3d_v2(Tscal dr, std::pair<Tvec, Tvec> _box);
137
138 inline std::unique_ptr<modules::SPHSetup<Tvec, SPHKernel>> get_setup() {
139 return std::make_unique<modules::SPHSetup<Tvec, SPHKernel>>(
140 ctx, solver.solver_config, solver.storage);
141 }
142
143 // std::function<Tscal(Tscal)> sigma_profile = [=](Tscal r, Tscal r_in, Tscal p){
144 // // we setup with an adimensional mass since it is monte carlo
145 // constexpr Tscal sigma_0 = 1;
146 // return sigma_0*sycl::pow(r/r_in, -p);
147 // };
148 //
149 // std::function<Tscal(Tscal)> cs_law = [=](Tscal r, Tscal r_in, Tscal q){
150 // return sycl::pow(r/r_in, -q);
151 // };
152 //
153 // std::function<Tscal(Tscal)> rot_profile = [=](Tscal r, Tscal central_mass){
154 // Tscal G = solver.solver_config.get_constant_G();
155 // return sycl::sqrt(G * central_mass/r);
156 // };
157 //
158 // std::function<Tscal(Tscal)> cs_profile = [&](Tscal r, Tscal r_in, Tscal H_r_in){
159 // Tscal cs_in = H_r_in*rot_profile(r_in);
160 // return cs_law(r)*cs_in;
161
162 void add_big_disc_3d(
163 Tvec center,
164 Tscal central_mass,
165 u32 Npart,
166 Tscal r_in,
167 Tscal r_out,
168 Tscal disc_mass,
169 Tscal p,
170 Tscal H_r_in,
171 Tscal q,
172 std::mt19937 eng);
173
174 inline void add_sink(Tscal mass, Tvec pos, Tvec velocity, Tscal accretion_radius) {
175 if (solver.storage.sinks.is_empty()) {
176 solver.storage.sinks.set({});
177 }
178
179 shamlog_debug_ln("SPH", "add sink :", mass, pos, velocity, accretion_radius);
180
181 solver.storage.sinks.get().push_back(
182 {pos, velocity, {}, {}, mass, {}, accretion_radius});
183 }
184
185 template<class T>
186 inline void set_field_value_lambda(
187 std::string field_name, const std::function<T(Tvec)> pos_to_val, const u32 offset) {
188
189 StackEntry stack_loc{};
190
191 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
192
193 u32 ixyz = sched.pdl_old().get_field_idx<Tvec>("xyz");
194 u32 ifield = sched.pdl_old().get_field_idx<T>(field_name);
195
196 sched.patch_data.for_each_patchdata(
197 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
198 PatchDataField<Tvec> &xyz = pdat.template get_field<Tvec>(ixyz);
199 PatchDataField<T> &f = pdat.template get_field<T>(ifield);
200
201 auto f_nvar = f.get_nvar();
202 if (offset >= f_nvar) {
204 "offset ({}) is out of bounds for field '{}' with nvar {}",
205 offset,
206 field_name,
207 f_nvar));
208 }
209
210 auto acc = f.get_buf().copy_to_stdvec();
211 auto acc_xyz = xyz.get_buf().copy_to_stdvec();
212
213 u32 obj_cnt = pdat.get_obj_cnt();
214 for (u32 i = 0; i < obj_cnt; i++) {
215 acc[i * f_nvar + offset] = pos_to_val(acc_xyz[i]);
216 }
217
218 f.get_buf().copy_from_stdvec(acc);
219 });
220 }
221
222 template<class T>
223 inline void overwrite_field_value(
224 std::string field_name,
225 const std::function<std::vector<T>(py::dict)> field_compute,
226 const u32 offset) {
227
228 StackEntry stack_loc{};
229
230 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
231
232 u32 ifield = sched.pdl_old().get_field_idx<T>(field_name);
233
234 sched.patch_data.for_each_patchdata(
235 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
236 PatchDataField<T> &f = pdat.template get_field<T>(ifield);
237
238 auto f_nvar = f.get_nvar();
239 if (offset >= f_nvar) {
241 "offset ({}) is out of bounds for field '{}' with nvar {}",
242 offset,
243 field_name,
244 f_nvar));
245 }
246
247 auto result = field_compute(shamrock::pdat_to_dic(pdat));
248
249 if (result.size() != f.get_obj_cnt()) {
251 "result.size() != f.get_obj_cnt() ({} != {})",
252 result.size(),
253 f.get_obj_cnt()));
254 }
255
256 auto acc = f.get_buf().copy_to_stdvec();
257
258 u32 obj_cnt = pdat.get_obj_cnt();
259 for (u32 i = 0; i < obj_cnt; i++) {
260 acc[i * f_nvar + offset] = result[i];
261 }
262
263 f.get_buf().copy_from_stdvec(acc);
264 });
265 }
266
/// Add a disc of particles around `center` and return the particle mass (disc_mass / Npart).
///
/// Radial profiles defined by the lambdas below:
///  - sigma_profile : (r / r_in)^-p (sigma_0 = 1)
///  - cs_profile    : (r / r_in)^-q * H_r_in * sqrt(G * central_mass / r_in)
///  - rot_profile   : sqrt(G * central_mass / r)
/// The generated outputs are collected in `part_list`, converted to field arrays
/// (xyz, vxyz, uint, hpart and, for a locally isothermal EOS, soundspeed) and inserted in the
/// local patch data. The scheduler is then stepped, objects are reattributed on "xyz", the
/// locality is checked and the per patch particle counts are logged on rank 0.
///
/// `eos_gamma` is the adiabatic index when the EOS config holds the Adiabatic alternative,
/// 2 otherwise.
///
/// NOTE(review): some statement heads are not visible in this block (e.g. the definition of
/// `Out`, the generator call receiving Npart/r_in/r_out and the profiles, the declarations of
/// `f`, `sptree` and `reatrib`) — confirm against the complete file.
281 template<std::enable_if_t<dim == 3, int> = 0>
282 inline Tscal add_disc_3d(
283 Tvec center,
284 Tscal central_mass,
285 u32 Npart,
286 Tscal r_in,
287 Tscal r_out,
288 Tscal disc_mass,
289 Tscal p,
290 Tscal H_r_in,
291 Tscal q) {
292
293 Tscal G = solver.solver_config.get_constant_G();
294
295 Tscal eos_gamma;
296 using Config = SolverConfig;
297 using SolverConfigEOS = typename Config::EOSConfig;
298 using SolverEOS_Adiabatic = typename SolverConfigEOS::Adiabatic;
299 if (SolverEOS_Adiabatic *eos_config
300 = std::get_if<SolverEOS_Adiabatic>(&solver.solver_config.eos_config.config)) {
301
302 eos_gamma = eos_config->gamma;
303
304 } else {
305 // dirty hack for disc setup in locally isothermal
306 eos_gamma = 2;
307 // shambase::throw_unimplemented();
308 }
309
311
312 auto sigma_profile = [=](Tscal r) {
313 // we setup with an adimensional mass since it is monte carlo
314 constexpr Tscal sigma_0 = 1;
315 return sigma_0 * sycl::pow(r / r_in, -p);
316 };
317
318 auto cs_law = [=](Tscal r) {
319 return sycl::pow(r / r_in, -q);
320 };
321
322 auto rot_profile = [=](Tscal r) {
323 return sycl::sqrt(G * central_mass / r);
324 };
325
// sound speed at r_in, cs_profile rescales cs_law with it
326 Tscal cs_in = H_r_in * rot_profile(r_in);
327 auto cs_profile = [&](Tscal r) {
328 return cs_law(r) * cs_in;
329 };
330
331 std::vector<Out> part_list;
332
334 Npart,
335 r_in,
336 r_out,
337 [&](Tscal r) {
338 return sigma_profile(r);
339 },
340 [&](Tscal r) {
341 return cs_profile(r);
342 },
343 [&](Tscal r) {
344 return rot_profile(r);
345 },
346 [&](Out out) {
347 part_list.push_back(out);
348 });
349
350 Tscal part_mass = disc_mass / Npart;
351
352 using namespace shamrock::patch;
353
354 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
355
356 std::string log = "";
357
// NOTE(review): the whole part_list is inserted in every local patch, `patch_coord` is
// computed but not used to filter positions — confirm this is intended.
358 sched.for_each_local_patchdata([&](const Patch &ptch, PatchDataLayer &pdat) {
359 PatchCoordTransform<Tvec> ptransf = sched.get_sim_box().get_patch_transform<Tvec>();
360
361 shammath::CoordRange<Tvec> patch_coord = ptransf.to_obj_coord(ptch);
362
363 std::vector<Tvec> vec_pos;
364 std::vector<Tvec> vec_vel;
365 std::vector<Tscal> vec_u;
366 std::vector<Tscal> vec_h;
367
368 std::vector<Tscal> vec_cs;
369
// NOTE(review): this G shadows the outer one and is not read in this lambda.
370 Tscal G = solver.solver_config.get_constant_G();
371
372 for (Out o : part_list) {
373 vec_pos.push_back(o.pos + center);
374 vec_vel.push_back(o.velocity);
375
376 // for disc with P = \rho u (/gamma - 1)
377 // the scaleheight : H = \sqrt{u (\gamma -1)}/\Omega_K
378 // therefore the effective soundspeed is : \sqrt{(\gamma -1)u}
379 // whereas the real one is \sqrt{(\gamma -1)\gamma u}
380 vec_u.push_back(o.cs * o.cs / (/*solver.eos_gamma * */ (eos_gamma - 1)));
381 vec_h.push_back(shamrock::sph::h_rho(part_mass, o.rho, Kernel::hfactd));
382 vec_cs.push_back(o.cs);
383 }
384
385 log += shambase::format(
386 "\n patch id={}, add N={} particles", ptch.id_patch, vec_pos.size());
387
// temporary patch data sized to the new particles, all fields reset before being filled
388 PatchDataLayer tmp(sched.get_layout_ptr_old());
389 tmp.resize(vec_pos.size());
390 tmp.fields_raz();
391
392 {
393 u32 len = vec_pos.size();
395 = tmp.get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
396 sycl::buffer<Tvec> buf(vec_pos.data(), len);
397 f.override(buf, len);
398 }
399
400 {
401 u32 len = vec_pos.size();
403 = tmp.get_field<Tscal>(sched.pdl_old().get_field_idx<Tscal>("hpart"));
404 sycl::buffer<Tscal> buf(vec_h.data(), len);
405 f.override(buf, len);
406 }
407
408 {
409 u32 len = vec_pos.size();
411 = tmp.get_field<Tscal>(sched.pdl_old().get_field_idx<Tscal>("uint"));
412 sycl::buffer<Tscal> buf(vec_u.data(), len);
413 f.override(buf, len);
414 }
415
// the soundspeed field is only filled for a locally isothermal EOS
416 if (solver.solver_config.is_eos_locally_isothermal()) {
417 u32 len = vec_pos.size();
419 = tmp.get_field<Tscal>(sched.pdl_old().get_field_idx<Tscal>("soundspeed"));
420 sycl::buffer<Tscal> buf(vec_cs.data(), len);
421 f.override(buf, len);
422 }
423
424 {
425 u32 len = vec_pos.size();
427 = tmp.get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("vxyz"));
428 sycl::buffer<Tvec> buf(vec_vel.data(), len);
429 f.override(buf, len);
430 }
431
432 pdat.insert_elements(tmp);
433 });
434
// gather the per patch messages and print them on rank 0
435 std::string log_gathered = "";
436 shamalgs::collective::gather_str(log, log_gathered);
437
438 if (shamcomm::world_rank() == 0) {
439 logger::info_ln("Model", "Push particles : ", log_gathered);
440 }
441
443 ctx, solver.solver_config, solver.storage)
444 .update_load_balancing();
445
// scheduler step without split/merge and without load balancing
446 sched.scheduler_step(false, false);
447
448 {
449 auto [m, M] = sched.get_box_tranform<Tvec>();
450
452 sched.patch_tree, sched.get_sim_box().get_patch_transform<Tvec>());
453
454 // sptree.print_status();
455
457
458 sptree.attach_buf();
459 // reatribute_particles(sched, sptree, periodic_mode);
460
461 reatrib.reatribute_patch_objects(sptree, "xyz");
462 }
463
464 sched.check_patchdata_locality_correctness();
465
// scheduler step with split/merge and load balancing enabled
466 sched.scheduler_step(true, true);
467
468 log = "";
469 sched.for_each_local_patchdata([&](const Patch &p, PatchDataLayer &pdat) {
470 log += shambase::format(
471 "\n patch id={}, N={} particles", p.id_patch, pdat.get_obj_cnt());
472 });
473
474 log_gathered = "";
475 shamalgs::collective::gather_str(log, log_gathered);
476
477 if (shamcomm::world_rank() == 0)
478 logger::info_ln("Model", "current particle counts : ", log_gathered);
479 return part_mass;
480 }
481
/// Add a disc of particles around `center` using a pusher callback receiving (r, h).
///
/// For every pushed position r: the particle is stored at r + center, its velocity is
/// sqrt(G * cmass / |r|) along the normalised direction (-r.z, 0, r.x), and its internal
/// energy is U(cs) with cs = |r|^-q (cs0 = 1). "hpart" is overridden with 0.01 for the new
/// particles. After insertion the scheduler is stepped, objects are reattributed on "xyz",
/// the locality is checked and the per patch particle counts are logged on rank 0.
///
/// NOTE(review): the head of the generator call receiving
/// (Npart, p, rho_0, m, r_in, r_out, q, pusher) is not visible in this block — confirm
/// against the complete file.
482 template<std::enable_if_t<dim == 3, int> = 0>
483 inline void add_cube_disc_3d(
484 Tvec center,
485 u32 Npart,
486 Tscal p,
487 Tscal rho_0,
488 Tscal m,
489 Tscal r_in,
490 Tscal r_out,
491 Tscal q,
492 Tscal cmass) {
493
494 Tscal eos_gamma;
495 using Config = SolverConfig;
496 using SolverConfigEOS = typename Config::EOSConfig;
497 using SolverEOS_Adiabatic = typename SolverConfigEOS::Adiabatic;
498 if (SolverEOS_Adiabatic *eos_config
499 = std::get_if<SolverEOS_Adiabatic>(&solver.solver_config.eos_config.config)) {
500
501 eos_gamma = eos_config->gamma;
502
// NOTE(review): no statement is visible in this branch, so eos_gamma would stay
// uninitialised for a non adiabatic EOS — confirm.
503 } else {
505 }
506
// NOTE(review): this `cs` lambda is not called in the visible code.
507 auto cs = [&](Tscal u) {
508 return sycl::sqrt(eos_gamma * (eos_gamma - 1) * u);
509 };
510
// internal energy from a sound speed : cs^2 / (gamma (gamma - 1))
511 auto U = [&](Tscal cs) {
512 return cs * cs / (eos_gamma * (eos_gamma - 1));
513 };
514
515 using namespace shamrock::patch;
516
517 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
518
519 std::string log = "";
520
521 sched.for_each_local_patchdata([&](const Patch &ptch, PatchDataLayer &pdat) {
522 PatchCoordTransform<Tvec> ptransf = sched.get_sim_box().get_patch_transform<Tvec>();
523
// NOTE(review): patch_coord is computed but not read in the visible code.
524 shammath::CoordRange<Tvec> patch_coord = ptransf.to_obj_coord(ptch);
525
526 std::vector<Tvec> vec_acc;
527 std::vector<Tvec> vec_vel;
528 std::vector<Tscal> vec_u;
529
530 Tscal G = solver.solver_config.get_constant_G();
531
533 Npart, p, rho_0, m, r_in, r_out, q, [&](Tvec r, Tscal h) {
534 vec_acc.push_back(r + center);
535
536 Tscal R = sycl::length(r);
537
538 Tscal V = sycl::sqrt(G * cmass / R);
539
// NOTE(review): etheta has zero length when r.x == r.z == 0, the normalisation then
// divides by zero — confirm such positions cannot be produced.
540 Tvec etheta = {-r.z(), 0, r.x()};
541 etheta /= sycl::length(etheta);
542
543 vec_vel.push_back(V * etheta);
544
545 Tscal cs0 = 1;
546 Tscal cs = cs0 * sycl::pow(R, -q);
547
548 vec_u.push_back(U(cs));
549 });
550
551 log += shambase::format(
552 "\n patch id={}, add N={} particles", ptch.id_patch, vec_acc.size());
553
// temporary patch data sized to the new particles, all fields reset before being filled
554 PatchDataLayer tmp(sched.get_layout_ptr_old());
555 tmp.resize(vec_acc.size());
556 tmp.fields_raz();
557
558 {
559 u32 len = vec_acc.size();
560 PatchDataField<Tvec> &f
561 = tmp.get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
562 sycl::buffer<Tvec> buf(vec_acc.data(), len);
563 f.override(buf, len);
564 }
565
566 {
567 PatchDataField<Tscal> &f
568 = tmp.get_field<Tscal>(sched.pdl_old().get_field_idx<Tscal>("hpart"));
569 f.override(0.01);
570 }
571
572 {
573 u32 len = vec_acc.size();
574 PatchDataField<Tscal> &f
575 = tmp.get_field<Tscal>(sched.pdl_old().get_field_idx<Tscal>("uint"));
576 sycl::buffer<Tscal> buf(vec_u.data(), len);
577 f.override(buf, len);
578 }
579
580 {
581 u32 len = vec_acc.size();
582 PatchDataField<Tvec> &f
583 = tmp.get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("vxyz"));
584 sycl::buffer<Tvec> buf(vec_vel.data(), len);
585 f.override(buf, len);
586 }
587
588 pdat.insert_elements(tmp);
589 });
590
// gather the per patch messages and print them on rank 0
591 std::string log_gathered = "";
592 shamalgs::collective::gather_str(log, log_gathered);
593
594 if (shamcomm::world_rank() == 0) {
595 logger::info_ln("Model", "Push particles : ", log_gathered);
596 }
597
598 modules::ComputeLoadBalanceValue<Tvec, SPHKernel>(
599 ctx, solver.solver_config, solver.storage)
600 .update_load_balancing();
601
// scheduler step without split/merge and without load balancing
602 sched.scheduler_step(false, false);
603
604 {
605 auto [m, M] = sched.get_box_tranform<Tvec>();
606
607 SerialPatchTree<Tvec> sptree(
608 sched.patch_tree, sched.get_sim_box().get_patch_transform<Tvec>());
609
610 // sptree.print_status();
611
612 shamrock::ReattributeDataUtility reatrib(sched);
613
614 sptree.attach_buf();
615 // reatribute_particles(sched, sptree, periodic_mode);
616
617 reatrib.reatribute_patch_objects(sptree, "xyz");
618 }
619
620 sched.check_patchdata_locality_correctness();
621
// scheduler step with split/merge and load balancing enabled
622 sched.scheduler_step(true, true);
623
624 log = "";
625 sched.for_each_local_patchdata([&](const Patch &p, PatchDataLayer &pdat) {
626 log += shambase::format(
627 "\n patch id={}, N={} particles", p.id_patch, pdat.get_obj_cnt());
628 });
629
630 log_gathered = "";
631 shamalgs::collective::gather_str(log, log_gathered);
632
633 if (shamcomm::world_rank() == 0)
634 logger::info_ln("Model", "current particle counts : ", log_gathered);
635 }
636
637 void remap_positions(std::function<Tvec(Tvec)> map);
638
639 void push_particle(
640 std::vector<Tvec> &part_pos_insert,
641 std::vector<Tscal> &part_hpart_insert,
642 std::vector<Tscal> &part_u_insert);
643
644 void push_particle_mhd(
645 std::vector<Tvec> &part_pos_insert,
646 std::vector<Tscal> &part_hpart_insert,
647 std::vector<Tscal> &part_u_insert,
648 std::vector<Tvec> &part_B_on_rho_insert,
649 std::vector<Tscal> &part_psi_on_ch_insert);
650
651 template<class T>
652 inline void set_value_in_a_box(
653 std::string field_name, T val, std::pair<Tvec, Tvec> box, u32 ivar) {
654 StackEntry stack_loc{};
655 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
656 sched.patch_data.for_each_patchdata(
657 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
658 PatchDataField<Tvec> &xyz
659 = pdat.template get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
660
661 PatchDataField<T> &f
662 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(field_name));
663
664 if (ivar >= f.get_nvar()) {
666 "You are trying to set value in a box for field ({}) with "
667 "ivar ({}) >= f.get_nvar ({})",
668 field_name,
669 ivar,
670 f.get_nvar()));
671 }
672
673 u32 nvar = f.get_nvar();
674
675 {
676 auto acc = f.get_buf().template mirror_to<sham::host>();
677 auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();
678
679 for (u32 i = 0; i < f.get_obj_cnt(); i++) {
680 Tvec r = acc_xyz[i];
681
682 if (BBAA::is_coord_in_range(r, std::get<0>(box), std::get<1>(box))) {
683 acc[i * nvar + ivar] = val;
684 }
685 }
686 }
687 });
688 }
689
690 template<class T>
691 inline void set_value_in_sphere(std::string field_name, T val, Tvec center, Tscal radius) {
692 StackEntry stack_loc{};
693 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
694 sched.patch_data.for_each_patchdata(
695 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
696 PatchDataField<Tvec> &xyz
697 = pdat.template get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
698
699 PatchDataField<T> &f
700 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(field_name));
701
702 if (f.get_nvar() != 1) {
704 }
705
706 Tscal r2 = radius * radius;
707 {
708 auto acc = f.get_buf().template mirror_to<sham::host>();
709 auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();
710
711 for (u32 i = 0; i < f.get_obj_cnt(); i++) {
712 Tvec dr = acc_xyz[i] - center;
713
714 if (sycl::dot(dr, dr) < r2) {
715 acc[i] = val;
716 }
717 }
718 }
719 });
720 }
721
722 template<class T>
723 inline void add_kernel_value(std::string field_name, T val, Tvec center, Tscal h_ker) {
724 StackEntry stack_loc{};
725 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
726 sched.patch_data.for_each_patchdata(
727 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
728 PatchDataField<Tvec> &xyz
729 = pdat.template get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
730
731 PatchDataField<T> &f
732 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(field_name));
733
734 if (f.get_nvar() != 1) {
736 }
737
738 {
739 auto acc = f.get_buf().template mirror_to<sham::host>();
740 auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();
741
742 for (u32 i = 0; i < f.get_obj_cnt(); i++) {
743 Tvec dr = acc_xyz[i] - center;
744
745 Tscal r = sycl::length(dr);
746
747 acc[i] += val * Kernel::W_3d(r, h_ker);
748 }
749 }
750 });
751 }
752
753 template<class T>
754 inline T get_sum(std::string name) {
755 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
756 T sum = shambase::VectorProperties<T>::get_zero();
757
758 StackEntry stack_loc{};
759 sched.patch_data.for_each_patchdata(
760 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
761 PatchDataField<T> &xyz
762 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(name));
763
764 sum += xyz.compute_sum();
765 });
766
767 return shamalgs::collective::allreduce_sum(sum);
768 }
769
770 Tvec get_closest_part_to(Tvec pos);
771
772 inline void apply_momentum_offset(Tvec offset) {
773
774 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
775
776 u32 ivxyz = sched.pdl_old().get_field_idx<Tvec>("vxyz");
777
778 // compute the total mass
779 Tscal tot_mass = 0;
780
781 sched.for_each_patchdata_nonempty(
782 [&](shamrock::patch::Patch p, shamrock::patch::PatchDataLayer &pdat) {
783 tot_mass += solver.solver_config.gpart_mass * pdat.get_obj_cnt();
784 });
785
786 tot_mass = shamalgs::collective::allreduce_sum(tot_mass);
787
788 // add the mass of the sinks
789 if (!solver.storage.sinks.is_empty()) {
790 for (auto &s : solver.storage.sinks.get()) {
791 tot_mass += s.mass;
792 }
793 }
794
795 // compute the offset velocity
796 Tvec offset_vel = (tot_mass > 0) ? (offset / tot_mass)
797 : shambase::VectorProperties<Tvec>::get_zero();
798
799 // apply the offset velocity to the sinks
800 if (!solver.storage.sinks.is_empty()) {
801 for (auto &s : solver.storage.sinks.get()) {
802 s.velocity += offset_vel;
803 }
804 }
805
806 // apply the offset velocity to the particles
807 sched.for_each_patchdata_nonempty(
808 [&](shamrock::patch::Patch p, shamrock::patch::PatchDataLayer &pdat) {
809 PatchDataField<Tvec> &vxyz = pdat.get_field<Tvec>(ivxyz);
810 vxyz.apply_offset(offset_vel);
811 });
812 }
813
814 inline void apply_position_offset(Tvec offset) {
815
816 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
817
818 u32 ixyz = sched.pdl_old().get_field_idx<Tvec>("xyz");
819
820 // apply the position offset to the sinks
821 if (!solver.storage.sinks.is_empty()) {
822 for (auto &s : solver.storage.sinks.get()) {
823 s.pos += offset;
824 }
825 }
826
827 // apply the position offset to the particles
828 sched.for_each_patchdata_nonempty(
829 [&](shamrock::patch::Patch p, shamrock::patch::PatchDataLayer &pdat) {
830 PatchDataField<Tvec> &xyz = pdat.get_field<Tvec>(ixyz);
831 xyz.apply_offset(offset);
832 });
833 }
834
835 // inline void enable_barotropic_mode(){
836 // sconfig.enable_barotropic();
837 // }
838 //
839 // inline void switch_internal_energy_mode(std::string name){
840 // sconfig.switch_internal_energy_mode(name);
841 // }
842
843 inline void set_solver_config(typename Solver::Config cfg) {
844 if (ctx.is_scheduler_initialized()) {
846 "Cannot change solver config after scheduler is initialized");
847 }
848 cfg.check_config();
849 solver.solver_config = cfg;
850 }
851
852 inline f64 solver_logs_last_rate() { return solver.solve_logs.get_last_rate(); }
853 inline u64 solver_logs_last_obj_count() { return solver.solve_logs.get_last_obj_count(); }
854 inline f64 solver_logs_cumulated_step_time() {
855 return solver.solve_logs.get_cumulated_step_time();
856 }
857 inline void solver_logs_reset_cumulated_step_time() {
858 solver.solve_logs.reset_cumulated_step_time();
859 }
860 inline u64 solver_logs_step_count() { return solver.solve_logs.get_step_count(); }
861 inline void solver_logs_reset_step_count() { solver.solve_logs.reset_step_count(); }
862
863 inline void change_htolerances(Tscal in_coarse, Tscal in_fine) {
864 if (in_coarse < in_fine) {
866 "in_coarse ({}) must be greater than in_fine ({})", in_coarse, in_fine));
867 }
868 solver.solver_config.htol_up_coarse_cycle = in_coarse;
869 solver.solver_config.htol_up_fine_cycle = in_fine;
870 }
871
875
879
/// Load the state of the SPH model from the dump file `fname`.
///
/// Steps visible below: the dump is read with shamrock::load_shamrock_dump (which also
/// returns the user metadata string), the metadata is parsed as JSON to restore
/// `solver.solver_config` and, when the "sinks" entry is not null, the sink list; the solver
/// ghost layout and solver graph are initialised, then the scheduler local tables and local
/// load values (object count of each owned patch) are rebuilt.
///
/// NOTE(review): nlohmann::json::at throws if "solver_config" or "sinks" is absent from the
/// metadata — confirm dumps always contain both keys.
885 inline void load_from_dump(std::string fname) {
886 if (shamcomm::world_rank() == 0) {
887 logger::info_ln("SPH", "Loading state from dump", fname);
888 }
889
890 // Load the context state and recover user metadata
891 std::string metadata_user{};
892 shamrock::load_shamrock_dump(fname, metadata_user, ctx);
893
// restore the solver configuration from the JSON metadata
895 nlohmann::json j = nlohmann::json::parse(metadata_user);
896 // std::cout << j << std::endl;
897 j.at("solver_config").get_to(solver.solver_config);
898
// a null "sinks" entry means that no sink list was stored
899 if (!j.at("sinks").is_null()) {
900 std::vector<SinkParticle<Tvec>> out;
901 j.at("sinks").get_to(out);
902 solver.storage.sinks.set(std::move(out));
903 }
904
905 solver.init_ghost_layout();
906
907 solver.init_solver_graph();
908
909 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
910 shamlog_debug_ln("Sys", "build local scheduler tables");
911 sched.owned_patch_id = sched.patch_list.build_local();
// the load value of a patch is the object count of its owned patch data
914 sched.update_local_load_value([&](shamrock::patch::Patch p) {
915 return sched.patch_data.owned_data.get(p.id_patch).get_obj_cnt();
916 });
917 }
918
924 inline void dump(std::string fname) {
925 if (shamcomm::world_rank() == 0) {
926 logger::info_ln("SPH", "Dumping state to", fname);
927 }
928
929 solver.update_sync_load_values();
930
931 nlohmann::json metadata;
932 metadata["solver_config"] = solver.solver_config;
933
934 if (solver.storage.sinks.is_empty()) {
935 metadata["sinks"] = nlohmann::json{};
936 } else {
937 metadata["sinks"] = solver.storage.sinks.get();
938 }
939
940 // Dump the state of the SPH model to a file
943 fname, metadata.dump(4), shambase::get_check_ref(ctx.sched));
944 }
945
949
950 f64 evolve_once_time_expl(f64 t_curr, f64 dt_input);
951
952 TimestepLog timestep();
953
954 inline void evolve_once() {
955 solver.evolve_once();
956 solver.print_timestep_logs();
957 }
958
959 inline EvolveUntilResults evolve_until(
960 Tscal target_time, i32 niter_max, f64 max_walltime = -1) {
961 return solver.evolve_until(target_time, niter_max, max_walltime);
962 }
963
964 private:
965 void add_pdat_to_phantom_block(
966 PhantomDumpBlock &block, shamrock::patch::PatchDataLayer &pdat);
967
968 template<class Tscal>
969 inline void warp_disc(
970 std::vector<Tvec> &pos,
971 std::vector<Tvec> &vel,
972 Tscal posangle,
973 Tscal incl,
974 Tscal Rwarp,
975 Tscal Hwarp) {
976 Tvec k = Tvec(-std::sin(posangle), std::cos(posangle), 0.);
977 Tscal inc;
978 Tscal psi = 0.;
979 u32 len = pos.size();
980
981 // convert to radians (sycl functions take radians)
982 Tscal incl_rad = incl * shambase::constants::pi<Tscal> / 180.;
983
984 for (i32 i = 0; i < len; i++) {
985 Tvec R_vec = pos[i];
986 Tscal R = sycl::sqrt(sycl::dot(R_vec, R_vec));
987 if (R < Rwarp - Hwarp) {
988 inc = 0.;
989 } else if (R < Rwarp + 3. * Hwarp && R > Rwarp - Hwarp) {
990 inc = sycl::asin(
991 0.5
992 * (1.
993 + sycl::sin(shambase::constants::pi<Tscal> / (2. * Hwarp) * (R - Rwarp)))
994 * sycl::sin(incl_rad));
995 psi = shambase::constants::pi<Tscal>
996 * Rwarp / (4. * Hwarp) * sycl::sin(incl_rad)
997 / sycl::sqrt(1. - (0.5 * sycl::pow(sycl::sin(incl_rad), 2)));
998 Tscal psimax = sycl::max(psimax, psi);
999 Tscal x = pos[i].x();
1000 Tscal y = pos[i].y();
1001 Tscal z = pos[i].z();
1002
1003 // Tscal xp = x * sycl::cos(inc) + y * sycl::sin(inc);
1004 // Tscal yp = - x * sycl::sin(inc) + y * sycl::cos(inc);
1005 // pos[i] = Tvec(xp, yp, z);
1006
1007 Tvec kk = Tvec(0., 0., 1.);
1008 Tvec w = sycl::cross(kk, pos[i]);
1009 // Rodrigues' rotation formula
1010 pos[i] = pos[i] * sycl::cos(inc) + w * sycl::sin(inc)
1011 + kk * sycl::dot(kk, pos[i]) * (1. - sycl::cos(inc));
1012
1013 } else {
1014 inc = 0.;
1015 }
1016 }
1017 }
1018
1019 inline void rotate_vector(Tvec &u, Tvec &v, Tscal theta) {
1020 // normalize the reference direction
1021 Tvec vunit = v / sycl::sqrt(sycl::dot(v, v));
1022 Tvec w = sycl::cross(vunit, u);
1023 // Rodrigues' rotation formula
1024 u = u * sycl::cos(theta) + w * sycl::sin(theta)
1025 + vunit * sycl::dot(vunit, u) * (1. - sycl::cos(theta));
1026 }
1027 };
1028
1029} // namespace shammodels::sph
constexpr const char * vxyz
3-velocity field
constexpr const char * xyz
Position field (3D coordinates).
Header file describing a Node Instance.
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
The MPI scheduler.
SchedulerPatchData patch_data
handle the data of the patches of the scheduler
PatchTree patch_tree
handle the tree structure of the patches
void scheduler_step(bool do_split_merge, bool do_load_balancing)
scheduler step
SchedulerPatchList patch_list
handle the list of the patches of the scheduler
std::unordered_set< u64 > owned_patch_id
(owned_patch_id = patch_list.build_local())
std::unordered_set< u64 > build_local()
select owned patches owned by the node to rebuild local
void build_local_idx_map()
recompute id_patch_to_local_idx
void build_global_idx_map()
recompute id_patch_to_global_idx
void load_from_dump(std::string fname)
Load the state of the SPH model from a dump file.
Definition Model.hpp:885
void init()
Initialise the model and all the related data structures (patch scheduler in particular).
Definition Model.cpp:56
Tscal add_disc_3d(Tvec center, Tscal central_mass, u32 Npart, Tscal r_in, Tscal r_out, Tscal disc_mass, Tscal p, Tscal H_r_in, Tscal q)
Add a disc distribution.
Definition Model.hpp:282
void dump(std::string fname)
Dump the state of the SPH model to a file.
Definition Model.hpp:924
void init_scheduler(u32 crit_split, u32 crit_merge)
Definition Model.hpp:82
Utility class used to move the objects between patches.
void reatribute_patch_objects(SerialPatchTree< T > &sptree, std::string position_field)
Reattribute objects based on a given position field.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
PatchCoordTransform< T > get_patch_transform() const
Get a PatchCoordTransform object that describes the conversion between patch coordinates and domain c...
Definition SimBox.hpp:285
shambase::DistributedData< PatchData > owned_data
map container for patchdata owned by the current node (layout : id_patch,data)
Class holding the value of numerous constants generated from the following source.
This header file contains utility functions related to exception handling in the code.
MPI string gather / allgather helpers (declarations; implementations in shamalgs/src/collective/gathe...
void gather_str(const std::string &send_vec, std::string &recv_vec)
Gathers a string from all nodes and stores the result in a std::string.
void add_disc2(u32 Npart, flt r_in, flt r_out, std::function< flt(flt)> sigma_profile, std::function< flt(flt)> cs_profile, std::function< flt(flt)> rot_profile, std::function< void(DiscOutput< flt >)> pusher)
void add_disc(u32 Npart, flt p, flt rho_0, flt m, flt r_in, flt r_out, flt q, Tpred_pusher &&part_pusher)
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
ExcptTypes make_except_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Create an exception with a message and a location.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for the sph model
void load_shamrock_dump(std::string fname, std::string &metadata_user, ShamrockCtx &ctx)
Load a Shamrock dump file and restore the state of the patches and retrieve user metadata.
void write_shamrock_dump(std::string fname, std::string metadata_user, PatchScheduler &sched)
Write a Shamrock dump file containing the current state of the patches and user supplied metadata.
void info_ln(std::string module_name, Types... var2)
Prints a log message with multiple arguments followed by a newline.
Definition logs.hpp:133
shambase::details::BasicStackEntry StackEntry
Alias for shambase::details::BasicStackEntry.
The configuration for a sph solver.
Patch object that contains generic patch information.
Definition Patch.hpp:33
u64 id_patch
unique key that identifies the patch
Definition Patch.hpp:86