Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
Model.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
22#include "shambase/string.hpp"
26#include "shambackends/vec.hpp"
27#include "shamcomm/logs.hpp"
42#include <pybind11/functional.h>
43#include <stdexcept>
44#include <vector>
45
46namespace shammodels::sph {
47
/// The shamrock SPH model.
///
/// Front-end object that holds a reference to the ShamrockCtx and owns the SPH Solver.
/// It exposes configuration setters, particle setup helpers, field-editing helpers,
/// dump/restore and time-evolution entry points; most members forward to `solver`.
///
/// NOTE(review): this text is a scrape of a Doxygen source listing. Every code line is
/// prefixed with its original line number, and several original lines are absent
/// (the doc comments, line 58 which presumably declares `dim`, and what appear to be the
/// first lines of the error-throwing calls and of several reference declarations).
/// It cannot be compiled as-is; edit code against the real Model.hpp.
54 template<class Tvec, template<class> class SPHKernel>
55 class Model {
56 public:
57 using Tscal = shambase::VecComponent<Tvec>;
59 using Kernel = SPHKernel<Tscal>;
60
61 using Solver = Solver<Tvec, SPHKernel>;
62 using SolverConfig = typename Solver::Config;
63 // using SolverConfig = typename Solver::Config;
64
// Context the model operates on (not owned).
65 ShamrockCtx &ctx;
66
// The SPH solver; holds solver_config, storage (incl. sinks) and solve_logs.
67 Solver solver;
68
69 // SolverConfig sconfig;
70
/// Binds the model to `ctx` and constructs the solver from the same context.
71 Model(ShamrockCtx &ctx) : ctx(ctx), solver(ctx) {};
72
76
/// Initialise the model and all the related data structures (patch scheduler in
/// particular). Defined out of line.
78 void init();
79
/// Stores the scheduler split/merge load thresholds in the solver config, then calls init().
82 inline void init_scheduler(u32 crit_split, u32 crit_merge) {
83 solver.solver_config.scheduler_conf.split_load_value = crit_split;
84 solver.solver_config.scheduler_conf.merge_load_value = crit_merge;
85 init();
86 }
87
/// Forwards to generic::setup::generators::get_box_dim(dr, xcnt, ycnt, zcnt).
/// NOTE(review): `dim` is not declared in this listing (presumably on the missing line 58).
/// The enable_if condition does not depend on a template parameter of this member, so it
/// is evaluated when the class is instantiated instead of acting as SFINAE — confirm intent.
88 template<std::enable_if_t<dim == 3, int> = 0>
89 inline Tvec get_box_dim_fcc_3d(Tscal dr, u32 xcnt, u32 ycnt, u32 zcnt) {
90 return generic::setup::generators::get_box_dim(dr, xcnt, ycnt, zcnt);
91 }
92
// Setters writing the CFL factors and the global particle mass into solver_config.
93 inline void set_cfl_cour(Tscal cfl_cour) {
94 solver.solver_config.cfl_config.cfl_cour = cfl_cour;
95 }
96 inline void set_cfl_force(Tscal cfl_force) {
97 solver.solver_config.cfl_config.cfl_force = cfl_force;
98 }
99 inline void set_eta_sink(Tscal eta_sink) {
100 solver.solver_config.cfl_config.eta_sink = eta_sink;
101 }
102 inline void set_particle_mass(Tscal gpart_mass) {
103 solver.solver_config.gpart_mass = gpart_mass;
104 }
105
/// Returns the global particle mass stored in solver_config.
106 inline Tscal get_particle_mass() { return solver.solver_config.gpart_mass; }
107
/// Sets the context coordinate domain bounds from the pair (box.first, box.second).
108 inline void resize_simulation_box(std::pair<Tvec, Tvec> box) {
109 ctx.set_coord_domain_bound({box.first, box.second});
110 }
111
// Phantom dump conversion helpers (defined out of line).
112 SolverConfig gen_config_from_phantom_dump(PhantomDump &phdump, bool bypass_error);
113 void init_from_phantom_dump(PhantomDump &phdump, Tscal hpart_fact_load = 1.0);
114 PhantomDump make_phantom_dump();
115
/// Writes a VTK dump through solver.vtk_do_dump.
116 void do_vtk_dump(std::string filename, bool add_patch_world_id) {
117 solver.vtk_do_dump(filename, add_patch_world_id);
118 }
119
/// Forwards the debug-dump switch and file name to the solver.
120 void set_debug_dump(bool _do_debug_dump, std::string _debug_dump_filename) {
121 solver.set_debug_dump(_do_debug_dump, _debug_dump_filename);
122 }
123
// Declared here, defined out of line.
124 u64 get_total_part_count();
125
126 f64 total_mass_to_part_mass(f64 totmass);
127
128 std::pair<Tvec, Tvec> get_ideal_fcc_box(Tscal dr, std::pair<Tvec, Tvec> box);
129 std::pair<Tvec, Tvec> get_ideal_hcp_box(Tscal dr, std::pair<Tvec, Tvec> box);
130
/// Returns the kernel constant Kernel::hfactd.
131 Tscal get_hfact() { return Kernel::hfactd; }
132
/// Density associated with smoothing length `h`, computed by shamrock::sph::rho_h from
/// the configured particle mass and Kernel::hfactd.
133 Tscal rho_h(Tscal h) {
134 return shamrock::sph::rho_h(solver.solver_config.gpart_mass, h, Kernel::hfactd);
135 }
136
// Lattice setup helpers (defined out of line).
137 void add_cube_fcc_3d(Tscal dr, std::pair<Tvec, Tvec> _box);
138 void add_cube_hcp_3d(Tscal dr, std::pair<Tvec, Tvec> _box);
139 void add_cube_hcp_3d_v2(Tscal dr, std::pair<Tvec, Tvec> _box);
140
/// Creates a new SPHSetup bound to this model's context, solver config and storage.
/// The caller owns the returned object.
141 inline std::unique_ptr<modules::SPHSetup<Tvec, SPHKernel>> get_setup() {
142 return std::make_unique<modules::SPHSetup<Tvec, SPHKernel>>(
143 ctx, solver.solver_config, solver.storage);
144 }
145
// NOTE(review): disabled draft of the radial profiles now implemented as local lambdas in
// add_disc_3d. The cs_profile draft has no closing `};`; consider deleting this block.
146 // std::function<Tscal(Tscal)> sigma_profile = [=](Tscal r, Tscal r_in, Tscal p){
147 // // we setup with an adimensional mass since it is monte carlo
148 // constexpr Tscal sigma_0 = 1;
149 // return sigma_0*sycl::pow(r/r_in, -p);
150 // };
151 //
152 // std::function<Tscal(Tscal)> cs_law = [=](Tscal r, Tscal r_in, Tscal q){
153 // return sycl::pow(r/r_in, -q);
154 // };
155 //
156 // std::function<Tscal(Tscal)> rot_profile = [=](Tscal r, Tscal central_mass){
157 // Tscal G = solver.solver_config.get_constant_G();
158 // return sycl::sqrt(G * central_mass/r);
159 // };
160 //
161 // std::function<Tscal(Tscal)> cs_profile = [&](Tscal r, Tscal r_in, Tscal H_r_in){
162 // Tscal cs_in = H_r_in*rot_profile(r_in);
163 // return cs_law(r)*cs_in;
164
/// Disc setup variant that takes an explicit random engine (defined out of line; its
/// behaviour is not visible here).
165 void add_big_disc_3d(
166 Tvec center,
167 Tscal central_mass,
168 u32 Npart,
169 Tscal r_in,
170 Tscal r_out,
171 Tscal disc_mass,
172 Tscal p,
173 Tscal H_r_in,
174 Tscal q,
175 std::mt19937 eng);
176
/// Appends a sink particle to solver.storage.sinks, creating an empty sink list first if
/// none is set. Sink fields not given here are value-initialised with `{}`.
177 inline void add_sink(Tscal mass, Tvec pos, Tvec velocity, Tscal accretion_radius) {
178 if (solver.storage.sinks.is_empty()) {
179 solver.storage.sinks.set({});
180 }
181
182 shamlog_debug_ln("SPH", "add sink :", mass, pos, velocity, accretion_radius);
183
184 solver.storage.sinks.get().push_back(
185 {pos, velocity, {}, {}, mass, {}, accretion_radius});
186 }
187
/// Sets component `offset` of field `field_name` to `pos_to_val(position)` for every
/// object of every patch.
/// The field and positions are copied to host std::vectors, edited, and copied back.
/// An `offset >= nvar` is rejected; the throw call itself is on a line missing from this
/// listing.
188 template<class T>
189 inline void set_field_value_lambda(
190 std::string field_name, const std::function<T(Tvec)> pos_to_val, const u32 offset) {
191
192 StackEntry stack_loc{};
193
194 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
195
196 u32 ixyz = sched.pdl_old().get_field_idx<Tvec>("xyz");
197 u32 ifield = sched.pdl_old().get_field_idx<T>(field_name);
198
199 sched.patch_data.for_each_patchdata(
200 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
201 PatchDataField<Tvec> &xyz = pdat.template get_field<Tvec>(ixyz);
202 PatchDataField<T> &f = pdat.template get_field<T>(ifield);
203
204 auto f_nvar = f.get_nvar();
205 if (offset >= f_nvar) {
207 "offset ({}) is out of bounds for field '{}' with nvar {}",
208 offset,
209 field_name,
210 f_nvar));
211 }
212
213 auto acc = f.get_buf().copy_to_stdvec();
214 auto acc_xyz = xyz.get_buf().copy_to_stdvec();
215
// field storage is interleaved: component `offset` of object i is at i * nvar + offset
216 u32 obj_cnt = pdat.get_obj_cnt();
217 for (u32 i = 0; i < obj_cnt; i++) {
218 acc[i * f_nvar + offset] = pos_to_val(acc_xyz[i]);
219 }
220
221 f.get_buf().copy_from_stdvec(acc);
222 });
223 }
224
/// For each patch, calls `field_compute` on the patch data converted with
/// shamrock::pdat_to_dic, then writes the returned values into component `offset` of
/// field `field_name`. The returned vector must hold exactly one value per object.
/// Violations (bad offset, wrong size) are rejected through throw calls whose first line
/// is missing from this listing.
225 template<class T>
226 inline void overwrite_field_value(
227 std::string field_name,
228 const std::function<std::vector<T>(py::dict)> field_compute,
229 const u32 offset) {
230
231 StackEntry stack_loc{};
232
233 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
234
235 u32 ifield = sched.pdl_old().get_field_idx<T>(field_name);
236
237 sched.patch_data.for_each_patchdata(
238 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
239 PatchDataField<T> &f = pdat.template get_field<T>(ifield);
240
241 auto f_nvar = f.get_nvar();
242 if (offset >= f_nvar) {
244 "offset ({}) is out of bounds for field '{}' with nvar {}",
245 offset,
246 field_name,
247 f_nvar));
248 }
249
250 auto result = field_compute(shamrock::pdat_to_dic(pdat));
251
252 if (result.size() != f.get_obj_cnt()) {
254 "result.size() != f.get_obj_cnt() ({} != {})",
255 result.size(),
256 f.get_obj_cnt()));
257 }
258
259 auto acc = f.get_buf().copy_to_stdvec();
260
261 u32 obj_cnt = pdat.get_obj_cnt();
262 for (u32 i = 0; i < obj_cnt; i++) {
263 acc[i * f_nvar + offset] = result[i];
264 }
265
266 f.get_buf().copy_from_stdvec(acc);
267 });
268 }
269
/// Add a disc distribution.
///
/// Generates `Npart` particles with generic::setup::generators::add_disc2 between `r_in`
/// and `r_out`, using these radial profiles:
///  - surface density ∝ (r / r_in)^-p;
///  - sound speed = cs_in * (r / r_in)^-q, with cs_in = H_r_in * sqrt(G * central_mass / r_in);
///  - rotation velocity sqrt(G * central_mass / r).
/// Positions are shifted by `center`, and the particles are inserted into the local
/// patches. The scheduler then rebalances and reattributes particles by "xyz".
/// @return the particle mass disc_mass / Npart. It is not written to solver_config here
///         (presumably the caller passes it to set_particle_mass — confirm).
/// NOTE(review): `Npart == 0` divides by zero when computing the particle mass.
284 template<std::enable_if_t<dim == 3, int> = 0>
285 inline Tscal add_disc_3d(
286 Tvec center,
287 Tscal central_mass,
288 u32 Npart,
289 Tscal r_in,
290 Tscal r_out,
291 Tscal disc_mass,
292 Tscal p,
293 Tscal H_r_in,
294 Tscal q) {
295
296 Tscal G = solver.solver_config.get_constant_G();
297
// gamma used below to turn the sound speed into internal energy
298 Tscal eos_gamma;
299 using Config = SolverConfig;
300 using SolverConfigEOS = typename Config::EOSConfig;
301 using SolverEOS_Adiabatic = typename SolverConfigEOS::Adiabatic;
302 if (SolverEOS_Adiabatic *eos_config
303 = std::get_if<SolverEOS_Adiabatic>(&solver.solver_config.eos_config.config)) {
304
305 eos_gamma = eos_config->gamma;
306
307 } else {
308 // dirty hack for disc setup in locally isothermal
309 eos_gamma = 2;
310 // shambase::throw_unimplemented();
311 }
312
// `Out` is not declared in this listing (presumably on missing line 313 — the disc
// generator's output type, exposing pos/velocity/cs/rho).
314
315 auto sigma_profile = [=](Tscal r) {
316 // we setup with an adimensional mass since it is monte carlo
317 constexpr Tscal sigma_0 = 1;
318 return sigma_0 * sycl::pow(r / r_in, -p);
319 };
320
321 auto cs_law = [=](Tscal r) {
322 return sycl::pow(r / r_in, -q);
323 };
324
325 auto rot_profile = [=](Tscal r) {
326 return sycl::sqrt(G * central_mass / r);
327 };
328
329 Tscal cs_in = H_r_in * rot_profile(r_in);
330 auto cs_profile = [&](Tscal r) {
331 return cs_law(r) * cs_in;
332 };
333
// generate the whole particle list once, before touching the patches
334 std::vector<Out> part_list;
335
336 generic::setup::generators::add_disc2<Tscal>(
337 Npart,
338 r_in,
339 r_out,
340 [&](Tscal r) {
341 return sigma_profile(r);
342 },
343 [&](Tscal r) {
344 return cs_profile(r);
345 },
346 [&](Tscal r) {
347 return rot_profile(r);
348 },
349 [&](Out out) {
350 part_list.push_back(out);
351 });
352
353 Tscal part_mass = disc_mass / Npart;
354
355 using namespace shamrock::patch;
356
357 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
358
359 std::string log = "";
360
// NOTE(review): every local patch receives the *entire* part_list (patch_coord is
// computed but never used to filter). With more than one local patch or MPI rank the
// disc appears to be inserted several times — confirm this is only called on a single
// patch, or filter by patch_coord.
361 sched.for_each_local_patchdata([&](const Patch &ptch, PatchDataLayer &pdat) {
362 PatchCoordTransform<Tvec> ptransf = sched.get_sim_box().get_patch_transform<Tvec>();
363
364 shammath::CoordRange<Tvec> patch_coord = ptransf.to_obj_coord(ptch);
365
366 std::vector<Tvec> vec_pos;
367 std::vector<Tvec> vec_vel;
368 std::vector<Tscal> vec_u;
369 std::vector<Tscal> vec_h;
370
371 std::vector<Tscal> vec_cs;
372
// shadows the outer G (same value, same getter)
373 Tscal G = solver.solver_config.get_constant_G();
374
375 for (Out o : part_list) {
376 vec_pos.push_back(o.pos + center);
377 vec_vel.push_back(o.velocity);
378
379 // for disc with P = \rho u (/gamma - 1)
380 // the scaleheight : H = \sqrt{u (\gamma -1)}/\Omega_K
381 // therefore the effective soundspeed is : \sqrt{(\gamma -1)u}
382 // whereas the real one is \sqrt{(\gamma -1)\gamma u}
383 vec_u.push_back(o.cs * o.cs / (/*solver.eos_gamma * */ (eos_gamma - 1)));
384 vec_h.push_back(shamrock::sph::h_rho(part_mass, o.rho, Kernel::hfactd));
385 vec_cs.push_back(o.cs);
386 }
387
388 log += shambase::format(
389 "\n patch id={}, add N={} particles", ptch.id_patch, vec_pos.size());
390
// stage the new particles in a temporary layer with the scheduler layout; fields_raz()
// presumably zeroes every field before the known ones are overwritten below
391 PatchDataLayer tmp(sched.get_layout_ptr_old());
392 tmp.resize(vec_pos.size());
393 tmp.fields_raz();
394
395 {
396 u32 len = vec_pos.size();
398 = tmp.get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
399 sycl::buffer<Tvec> buf(vec_pos.data(), len);
400 f.override(buf, len);
401 }
402
403 {
404 u32 len = vec_pos.size();
406 = tmp.get_field<Tscal>(sched.pdl_old().get_field_idx<Tscal>("hpart"));
407 sycl::buffer<Tscal> buf(vec_h.data(), len);
408 f.override(buf, len);
409 }
410
411 {
412 u32 len = vec_pos.size();
414 = tmp.get_field<Tscal>(sched.pdl_old().get_field_idx<Tscal>("uint"));
415 sycl::buffer<Tscal> buf(vec_u.data(), len);
416 f.override(buf, len);
417 }
418
// the "soundspeed" field is only filled for a locally isothermal EOS
419 if (solver.solver_config.is_eos_locally_isothermal()) {
420 u32 len = vec_pos.size();
422 = tmp.get_field<Tscal>(sched.pdl_old().get_field_idx<Tscal>("soundspeed"));
423 sycl::buffer<Tscal> buf(vec_cs.data(), len);
424 f.override(buf, len);
425 }
426
427 {
428 u32 len = vec_pos.size();
430 = tmp.get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("vxyz"));
431 sycl::buffer<Tvec> buf(vec_vel.data(), len);
432 f.override(buf, len);
433 }
434
435 pdat.insert_elements(tmp);
436 });
437
438 std::string log_gathered = "";
439 shamalgs::collective::gather_str(log, log_gathered);
440
441 if (shamcomm::world_rank() == 0) {
442 logger::info_ln("Model", "Push particles : ", log_gathered);
443 }
444
// Missing lines 445/454/459 presumably hold the ComputeLoadBalanceValue construction
// (as in add_cube_disc_3d), the `sptree` declaration and the `reatrib` declaration.
// Sequence: update load values, scheduler step without split/merge, move objects to
// the patch owning their "xyz", check locality, then split/merge + load balancing.
446 ctx, solver.solver_config, solver.storage)
447 .update_load_balancing();
448
449 sched.scheduler_step(false, false);
450
451 {
// m / M are not used in the visible code
452 auto [m, M] = sched.get_box_tranform<Tvec>();
453
455 sched.patch_tree, sched.get_sim_box().get_patch_transform<Tvec>());
456
457 // sptree.print_status();
458
460
461 sptree.attach_buf();
462 // reatribute_particles(sched, sptree, periodic_mode);
463
464 reatrib.reatribute_patch_objects(sptree, "xyz");
465 }
466
467 sched.check_patchdata_locality_correctness();
468
469 sched.scheduler_step(true, true);
470
471 log = "";
472 sched.for_each_local_patchdata([&](const Patch &p, PatchDataLayer &pdat) {
473 log += shambase::format(
474 "\n patch id={}, N={} particles", p.id_patch, pdat.get_obj_cnt());
475 });
476
477 log_gathered = "";
478 shamalgs::collective::gather_str(log, log_gathered);
479
480 if (shamcomm::world_rank() == 0)
481 logger::info_ln("Model", "current particle counts : ", log_gathered);
482 return part_mass;
483 }
484
/// Adds a disc generated by generic::setup::generators::add_disc (call on a missing
/// line), forwarding Npart, p, rho_0, m, r_in, r_out, q.
/// For each generated position r:
///  - the particle is placed at r + center;
///  - its speed is sqrt(G * cmass / |r|), along normalize(-r.z, 0, r.x), i.e. orbital
///    motion about the y axis;
///  - its internal energy comes from cs = |r|^-q.
/// hpart is set to the constant 0.01.
/// Requires an adiabatic EOS; the non-adiabatic branch sits on a missing line
/// (presumably an error).
/// NOTE(review):
///  - u = cs^2 / (gamma (gamma - 1)) here, whereas add_disc_3d uses cs^2 / (gamma - 1).
///  - The `cs` lambda and `patch_coord` are unused.
///  - `auto [m, M]` below shadows the parameter `m`.
///  - The generator is called once per local patch, so multiple patches may duplicate
///    the disc.
485 template<std::enable_if_t<dim == 3, int> = 0>
486 inline void add_cube_disc_3d(
487 Tvec center,
488 u32 Npart,
489 Tscal p,
490 Tscal rho_0,
491 Tscal m,
492 Tscal r_in,
493 Tscal r_out,
494 Tscal q,
495 Tscal cmass) {
496
497 Tscal eos_gamma;
498 using Config = SolverConfig;
499 using SolverConfigEOS = typename Config::EOSConfig;
500 using SolverEOS_Adiabatic = typename SolverConfigEOS::Adiabatic;
501 if (SolverEOS_Adiabatic *eos_config
502 = std::get_if<SolverEOS_Adiabatic>(&solver.solver_config.eos_config.config)) {
503
504 eos_gamma = eos_config->gamma;
505
506 } else {
508 }
509
510 auto cs = [&](Tscal u) {
511 return sycl::sqrt(eos_gamma * (eos_gamma - 1) * u);
512 };
513
// inverse of the relation above: internal energy from sound speed
514 auto U = [&](Tscal cs) {
515 return cs * cs / (eos_gamma * (eos_gamma - 1));
516 };
517
518 using namespace shamrock::patch;
519
520 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
521
522 std::string log = "";
523
524 sched.for_each_local_patchdata([&](const Patch &ptch, PatchDataLayer &pdat) {
525 PatchCoordTransform<Tvec> ptransf = sched.get_sim_box().get_patch_transform<Tvec>();
526
527 shammath::CoordRange<Tvec> patch_coord = ptransf.to_obj_coord(ptch);
528
529 std::vector<Tvec> vec_acc;
530 std::vector<Tvec> vec_vel;
531 std::vector<Tscal> vec_u;
532
533 Tscal G = solver.solver_config.get_constant_G();
534
536 Npart, p, rho_0, m, r_in, r_out, q, [&](Tvec r, Tscal h) {
537 vec_acc.push_back(r + center);
538
539 Tscal R = sycl::length(r);
540
// NOTE(review): R == 0 divides by zero here
541 Tscal V = sycl::sqrt(G * cmass / R);
542
543 Tvec etheta = {-r.z(), 0, r.x()};
544 etheta /= sycl::length(etheta);
545
546 vec_vel.push_back(V * etheta);
547
548 Tscal cs0 = 1;
549 Tscal cs = cs0 * sycl::pow(R, -q);
550
551 vec_u.push_back(U(cs));
552 });
553
554 log += shambase::format(
555 "\n patch id={}, add N={} particles", ptch.id_patch, vec_acc.size());
556
557 PatchDataLayer tmp(sched.get_layout_ptr_old());
558 tmp.resize(vec_acc.size());
559 tmp.fields_raz();
560
561 {
562 u32 len = vec_acc.size();
564 = tmp.get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
565 sycl::buffer<Tvec> buf(vec_acc.data(), len);
566 f.override(buf, len);
567 }
568
// every new particle gets the same fixed smoothing length 0.01 (magic constant)
569 {
571 = tmp.get_field<Tscal>(sched.pdl_old().get_field_idx<Tscal>("hpart"));
572 f.override(0.01);
573 }
574
575 {
576 u32 len = vec_acc.size();
578 = tmp.get_field<Tscal>(sched.pdl_old().get_field_idx<Tscal>("uint"));
579 sycl::buffer<Tscal> buf(vec_u.data(), len);
580 f.override(buf, len);
581 }
582
583 {
584 u32 len = vec_acc.size();
586 = tmp.get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("vxyz"));
587 sycl::buffer<Tvec> buf(vec_vel.data(), len);
588 f.override(buf, len);
589 }
590
591 pdat.insert_elements(tmp);
592 });
593
594 std::string log_gathered = "";
595 shamalgs::collective::gather_str(log, log_gathered);
596
597 if (shamcomm::world_rank() == 0) {
598 logger::info_ln("Model", "Push particles : ", log_gathered);
599 }
600
// same rebalance / reattribution sequence as add_disc_3d
601 modules::ComputeLoadBalanceValue<Tvec, SPHKernel>(
602 ctx, solver.solver_config, solver.storage)
603 .update_load_balancing();
604
605 sched.scheduler_step(false, false);
606
607 {
608 auto [m, M] = sched.get_box_tranform<Tvec>();
609
611 sched.patch_tree, sched.get_sim_box().get_patch_transform<Tvec>());
612
613 // sptree.print_status();
614
616
617 sptree.attach_buf();
618 // reatribute_particles(sched, sptree, periodic_mode);
619
620 reatrib.reatribute_patch_objects(sptree, "xyz");
621 }
622
623 sched.check_patchdata_locality_correctness();
624
625 sched.scheduler_step(true, true);
626
627 log = "";
628 sched.for_each_local_patchdata([&](const Patch &p, PatchDataLayer &pdat) {
629 log += shambase::format(
630 "\n patch id={}, N={} particles", p.id_patch, pdat.get_obj_cnt());
631 });
632
633 log_gathered = "";
634 shamalgs::collective::gather_str(log, log_gathered);
635
636 if (shamcomm::world_rank() == 0)
637 logger::info_ln("Model", "current particle counts : ", log_gathered);
638 }
639
// Declared here, defined out of line.
640 void remap_positions(std::function<Tvec(Tvec)> map);
641
642 void push_particle(
643 std::vector<Tvec> &part_pos_insert,
644 std::vector<Tscal> &part_hpart_insert,
645 std::vector<Tscal> &part_u_insert);
646
647 void push_particle_mhd(
648 std::vector<Tvec> &part_pos_insert,
649 std::vector<Tscal> &part_hpart_insert,
650 std::vector<Tscal> &part_u_insert,
651 std::vector<Tvec> &part_B_on_rho_insert,
652 std::vector<Tscal> &part_psi_on_ch_insert);
653
/// Sets component `ivar` of field `field_name` to `val` for every particle whose
/// position passes BBAA::is_coord_in_range(r, box.first, box.second). Whether the bounds
/// are inclusive is defined by BBAA, not visible here.
/// `ivar >= nvar` is rejected (the throw call starts on a missing line).
/// Edits go through host mirrors; the inner braces presumably scope the mirrors so the
/// data is written back when they are destroyed — confirm with sham::host mirror docs.
654 template<class T>
655 inline void set_value_in_a_box(
656 std::string field_name, T val, std::pair<Tvec, Tvec> box, u32 ivar) {
657 StackEntry stack_loc{};
658 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
659 sched.patch_data.for_each_patchdata(
660 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
662 = pdat.template get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
663
665 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(field_name));
666
667 if (ivar >= f.get_nvar()) {
669 "You are trying to set value in a box for field ({}) with "
670 "ivar ({}) >= f.get_nvar ({})",
671 field_name,
672 ivar,
673 f.get_nvar()));
674 }
675
676 u32 nvar = f.get_nvar();
677
678 {
679 auto acc = f.get_buf().template mirror_to<sham::host>();
680 auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();
681
682 for (u32 i = 0; i < f.get_obj_cnt(); i++) {
683 Tvec r = acc_xyz[i];
684
685 if (BBAA::is_coord_in_range(r, std::get<0>(box), std::get<1>(box))) {
686 acc[i * nvar + ivar] = val;
687 }
688 }
689 }
690 });
691 }
692
/// Sets the single-component field `field_name` to `val` for every particle strictly
/// inside the sphere |r - center| < radius.
/// A field with nvar != 1 is rejected (the handling line is missing from this listing).
693 template<class T>
694 inline void set_value_in_sphere(std::string field_name, T val, Tvec center, Tscal radius) {
695 StackEntry stack_loc{};
696 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
697 sched.patch_data.for_each_patchdata(
698 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
700 = pdat.template get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
701
703 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(field_name));
704
705 if (f.get_nvar() != 1) {
707 }
708
// compare squared distances to avoid a sqrt per particle
709 Tscal r2 = radius * radius;
710 {
711 auto acc = f.get_buf().template mirror_to<sham::host>();
712 auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();
713
714 for (u32 i = 0; i < f.get_obj_cnt(); i++) {
715 Tvec dr = acc_xyz[i] - center;
716
717 if (sycl::dot(dr, dr) < r2) {
718 acc[i] = val;
719 }
720 }
721 }
722 });
723 }
724
/// Adds `val * Kernel::W_3d(|r - center|, h_ker)` to the single-component field
/// `field_name` of every particle.
/// A field with nvar != 1 is rejected (the handling line is missing from this listing).
725 template<class T>
726 inline void add_kernel_value(std::string field_name, T val, Tvec center, Tscal h_ker) {
727 StackEntry stack_loc{};
728 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
729 sched.patch_data.for_each_patchdata(
730 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
732 = pdat.template get_field<Tvec>(sched.pdl_old().get_field_idx<Tvec>("xyz"));
733
735 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(field_name));
736
737 if (f.get_nvar() != 1) {
739 }
740
741 {
742 auto acc = f.get_buf().template mirror_to<sham::host>();
743 auto acc_xyz = xyz.get_buf().template mirror_to<sham::host>();
744
745 for (u32 i = 0; i < f.get_obj_cnt(); i++) {
746 Tvec dr = acc_xyz[i] - center;
747
748 Tscal r = sycl::length(dr);
749
750 acc[i] += val * Kernel::W_3d(r, h_ker);
751 }
752 }
753 });
754 }
755
/// Sums field `name` over all patches of this rank (PatchDataField::compute_sum), then
/// across ranks with shamalgs::collective::allreduce_sum.
/// NOTE(review): the accumulator `sum` is declared on a missing line (759). The local
/// called `xyz` holds field `name`, not positions — consider renaming.
756 template<class T>
757 inline T get_sum(std::string name) {
758 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
760
761 StackEntry stack_loc{};
762 sched.patch_data.for_each_patchdata(
763 [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) {
765 = pdat.template get_field<T>(sched.pdl_old().get_field_idx<T>(name));
766
767 sum += xyz.compute_sum();
768 });
769
770 return shamalgs::collective::allreduce_sum(sum);
771 }
772
// Declared here, defined out of line.
773 Tvec get_closest_part_to(Tvec pos);
774
/// Adds a total momentum `offset` to the system by shifting every velocity (sinks and
/// SPH particles) by offset / tot_mass. tot_mass is the particle mass summed over all
/// ranks plus the sink masses.
/// NOTE(review):
///  - The branch used when tot_mass <= 0 is on a missing line (800).
///  - Sink masses are added after the allreduce, which is only correct if every rank holds
///    the same sink list (presumably replicated — confirm).
775 inline void apply_momentum_offset(Tvec offset) {
776
777 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
778
779 u32 ivxyz = sched.pdl_old().get_field_idx<Tvec>("vxyz");
780
781 // compute the total mass
782 Tscal tot_mass = 0;
783
784 sched.for_each_patchdata_nonempty(
786 tot_mass += solver.solver_config.gpart_mass * pdat.get_obj_cnt();
787 });
788
789 tot_mass = shamalgs::collective::allreduce_sum(tot_mass);
790
791 // add the mass of the sinks
792 if (!solver.storage.sinks.is_empty()) {
793 for (auto &s : solver.storage.sinks.get()) {
794 tot_mass += s.mass;
795 }
796 }
797
798 // compute the offset velocity
799 Tvec offset_vel = (tot_mass > 0) ? (offset / tot_mass)
801
802 // apply the offset velocity to the sinks
803 if (!solver.storage.sinks.is_empty()) {
804 for (auto &s : solver.storage.sinks.get()) {
805 s.velocity += offset_vel;
806 }
807 }
808
809 // apply the offset velocity to the particles
810 sched.for_each_patchdata_nonempty(
812 PatchDataField<Tvec> &vxyz = pdat.get_field<Tvec>(ivxyz);
813 vxyz.apply_offset(offset_vel);
814 });
815 }
816
/// Translates every sink and every SPH particle position by `offset`.
817 inline void apply_position_offset(Tvec offset) {
818
819 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
820
821 u32 ixyz = sched.pdl_old().get_field_idx<Tvec>("xyz");
822
823 // apply the position offset to the sinks
824 if (!solver.storage.sinks.is_empty()) {
825 for (auto &s : solver.storage.sinks.get()) {
826 s.pos += offset;
827 }
828 }
829
830 // apply the position offset to the particles
831 sched.for_each_patchdata_nonempty(
833 PatchDataField<Tvec> &xyz = pdat.get_field<Tvec>(ixyz);
834 xyz.apply_offset(offset);
835 });
836 }
837
838 // inline void enable_barotropic_mode(){
839 // sconfig.enable_barotropic();
840 // }
841 //
842 // inline void switch_internal_energy_mode(std::string name){
843 // sconfig.switch_internal_energy_mode(name);
844 // }
845
/// Replaces the solver config after validating it with cfg.check_config().
/// Once the scheduler is initialised the change is refused; the throw call starts on a
/// missing line.
846 inline void set_solver_config(typename Solver::Config cfg) {
847 if (ctx.is_scheduler_initialized()) {
849 "Cannot change solver config after scheduler is initialized");
850 }
851 cfg.check_config();
852 solver.solver_config = cfg;
853 }
854
// Thin accessors over solver.solve_logs.
855 inline f64 solver_logs_last_rate() { return solver.solve_logs.get_last_rate(); }
856 inline u64 solver_logs_last_obj_count() { return solver.solve_logs.get_last_obj_count(); }
857 inline f64 solver_logs_cumulated_step_time() {
858 return solver.solve_logs.get_cumulated_step_time();
859 }
860 inline void solver_logs_reset_cumulated_step_time() {
861 solver.solve_logs.reset_cumulated_step_time();
862 }
863 inline u64 solver_logs_step_count() { return solver.solve_logs.get_step_count(); }
864 inline void solver_logs_reset_step_count() { solver.solve_logs.reset_step_count(); }
865
/// Sets the smoothing-length tolerances htol_up_coarse_cycle / htol_up_fine_cycle.
/// in_coarse < in_fine is rejected (throw call starts on a missing line). Equal values are
/// accepted even though the message says "greater than".
866 inline void change_htolerances(Tscal in_coarse, Tscal in_fine) {
867 if (in_coarse < in_fine) {
869 "in_coarse ({}) must be greater than in_fine ({})", in_coarse, in_fine));
870 }
871 solver.solver_config.htol_up_coarse_cycle = in_coarse;
872 solver.solver_config.htol_up_fine_cycle = in_fine;
873 }
874
878
882
/// Load the state of the SPH model from a dump file.
///
/// Steps:
///  - Restores the context with shamrock::load_shamrock_dump.
///  - Parses the user metadata as JSON. It reads the solver config from "solver_config"
///    and the sink list from "sinks" when that entry is not null.
///  - Re-initialises the solver ghost layout and solver graph.
///  - Rebuilds the scheduler's owned patch set. Missing lines 915-916 presumably rebuild
///    the local/global index maps.
///  - Resets local load values to each patch's object count.
888 inline void load_from_dump(std::string fname) {
889 if (shamcomm::world_rank() == 0) {
890 logger::info_ln("SPH", "Loading state from dump", fname);
891 }
892
893 // Load the context state and recover user metadata
894 std::string metadata_user{};
895 shamrock::load_shamrock_dump(fname, metadata_user, ctx);
896
898 nlohmann::json j = nlohmann::json::parse(metadata_user);
899 // std::cout << j << std::endl;
900 j.at("solver_config").get_to(solver.solver_config);
901
902 if (!j.at("sinks").is_null()) {
903 std::vector<SinkParticle<Tvec>> out;
904 j.at("sinks").get_to(out);
905 solver.storage.sinks.set(std::move(out));
906 }
907
908 solver.init_ghost_layout();
909
910 solver.init_solver_graph();
911
912 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
913 shamlog_debug_ln("Sys", "build local scheduler tables");
914 sched.owned_patch_id = sched.patch_list.build_local();
917 sched.update_local_load_value([&](shamrock::patch::Patch p) {
918 return sched.patch_data.owned_data.get(p.id_patch).get_obj_cnt();
919 });
920 }
921
/// Dump the state of the SPH model to a file.
///
/// Steps:
///  - Syncs load values.
///  - Builds JSON metadata holding the solver config under "solver_config" and the sinks
///    under "sinks". When there are no sinks, "sinks" is a default-constructed json,
///    which load_from_dump tests with is_null().
///  - Writes the patches plus the metadata (indented by 4). The write call starts on
///    missing lines 944-945, presumably shamrock::write_shamrock_dump.
927 inline void dump(std::string fname) {
928 if (shamcomm::world_rank() == 0) {
929 logger::info_ln("SPH", "Dumping state to", fname);
930 }
931
932 solver.update_sync_load_values();
933
934 nlohmann::json metadata;
935 metadata["solver_config"] = solver.solver_config;
936
937 if (solver.storage.sinks.is_empty()) {
938 metadata["sinks"] = nlohmann::json{};
939 } else {
940 metadata["sinks"] = solver.storage.sinks.get();
941 }
942
943 // Dump the state of the SPH model to a file
946 fname, metadata.dump(4), shambase::get_check_ref(ctx.sched));
947 }
948
952
// Declared here, defined out of line.
953 f64 evolve_once_time_expl(f64 t_curr, f64 dt_input);
954
955 TimestepLog timestep();
956
/// Runs one solver step, then prints the solver's timestep logs.
957 inline void evolve_once() {
958 solver.evolve_once();
959 solver.print_timestep_logs();
960 }
961
/// Forwards to solver.evolve_until(target_time, niter_max). The meaning of the returned
/// bool is defined by the solver (not visible here).
962 inline bool evolve_until(Tscal target_time, i32 niter_max) {
963 return solver.evolve_until(target_time, niter_max);
964 }
965
966 private:
// Declared here, defined out of line.
967 void add_pdat_to_phantom_block(
968 PhantomDumpBlock &block, shamrock::patch::PatchDataLayer &pdat);
969
/// Rotates particle positions inside a warp region around `Rwarp`. `incl` is in degrees
/// and is converted to radians below.
///  - For Rwarp - Hwarp < R < Rwarp + 3 Hwarp, the angle is
///    inc = asin(0.5 * (1 + sin(pi / (2 Hwarp) * (R - Rwarp))) * sin(incl)).
///  - Positions outside that range are left unchanged.
/// NOTE(review) issues visible in this member (no caller is visible in this listing):
///  - `psimax` is read in its own initializer (uninitialised read, undefined behaviour)
///    and is never used afterwards.
///  - `k`, `psi`, `x`, `y`, `z` and `vel` are computed or received but unused, so the
///    velocities are not warped.
///  - The rotation axis is `kk` = z, so positions only rotate within the xy plane. Confirm
///    whether `k` (built from `posangle`) was the intended axis.
///  - `i` (i32) is compared with `len` (u32).
///  - The member template parameter `Tscal` reuses the class alias name `Tscal`.
970 template<class Tscal>
971 inline void warp_disc(
972 std::vector<Tvec> &pos,
973 std::vector<Tvec> &vel,
974 Tscal posangle,
975 Tscal incl,
976 Tscal Rwarp,
977 Tscal Hwarp) {
978 Tvec k = Tvec(-std::sin(posangle), std::cos(posangle), 0.);
979 Tscal inc;
980 Tscal psi = 0.;
981 u32 len = pos.size();
982
983 // convert to radians (sycl functions take radians)
984 Tscal incl_rad = incl * shambase::constants::pi<Tscal> / 180.;
985
986 for (i32 i = 0; i < len; i++) {
987 Tvec R_vec = pos[i];
988 Tscal R = sycl::sqrt(sycl::dot(R_vec, R_vec));
989 if (R < Rwarp - Hwarp) {
990 inc = 0.;
991 } else if (R < Rwarp + 3. * Hwarp && R > Rwarp - Hwarp) {
992 inc = sycl::asin(
993 0.5
994 * (1.
995 + sycl::sin(shambase::constants::pi<Tscal> / (2. * Hwarp) * (R - Rwarp)))
996 * sycl::sin(incl_rad));
997 psi = shambase::constants::pi<Tscal>
998 * Rwarp / (4. * Hwarp) * sycl::sin(incl_rad)
999 / sycl::sqrt(1. - (0.5 * sycl::pow(sycl::sin(incl_rad), 2)));
1000 Tscal psimax = sycl::max(psimax, psi);
1001 Tscal x = pos[i].x();
1002 Tscal y = pos[i].y();
1003 Tscal z = pos[i].z();
1004
1005 // Tscal xp = x * sycl::cos(inc) + y * sycl::sin(inc);
1006 // Tscal yp = - x * sycl::sin(inc) + y * sycl::cos(inc);
1007 // pos[i] = Tvec(xp, yp, z);
1008
1009 Tvec kk = Tvec(0., 0., 1.);
1010 Tvec w = sycl::cross(kk, pos[i]);
1011 // Rodrigues' rotation formula
1012 pos[i] = pos[i] * sycl::cos(inc) + w * sycl::sin(inc)
1013 + kk * sycl::dot(kk, pos[i]) * (1. - sycl::cos(inc));
1014
1015 } else {
1016 inc = 0.;
1017 }
1018 }
1019 }
1020
/// Rotates `u` in place by `theta` radians about the axis `v`, which is normalised
/// internally, using Rodrigues' rotation formula.
/// `v` is taken by non-const reference but is not modified.
/// NOTE(review): a zero-length `v` divides by zero.
1021 inline void rotate_vector(Tvec &u, Tvec &v, Tscal theta) {
1022 // normalize the reference direction
1023 Tvec vunit = v / sycl::sqrt(sycl::dot(v, v));
1024 Tvec w = sycl::cross(vunit, u);
1025 // Rodrigues' rotation formula
1026 u = u * sycl::cos(theta) + w * sycl::sin(theta)
1027 + vunit * sycl::dot(vunit, u) * (1. - sycl::cos(theta));
1028 }
1029 };
1030
1031} // namespace shammodels::sph
constexpr const char * vxyz
3-velocity field
constexpr const char * xyz
Position field (3D coordinates)
Header file describing a Node Instance.
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
The MPI scheduler.
SchedulerPatchData patch_data
handle the data of the patches of the scheduler
PatchTree patch_tree
handle the tree structure of the patches
void scheduler_step(bool do_split_merge, bool do_load_balancing)
scheduler step
SchedulerPatchList patch_list
handle the list of the patches of the scheduler
std::unordered_set< u64 > owned_patch_id
(owned_patch_id = patch_list.build_local())
std::unordered_set< u64 > build_local()
Select the patches owned by the current node (used to rebuild the local patch set).
void build_local_idx_map()
recompute id_patch_to_local_idx
void build_global_idx_map()
recompute id_patch_to_global_idx
bool is_scheduler_initialized()
returns true if the scheduler is initialized
The shamrock SPH model.
Definition Model.hpp:55
void load_from_dump(std::string fname)
Load the state of the SPH model from a dump file.
Definition Model.hpp:888
void init()
Initialise the model and all the related data structures (patch scheduler in particular)
Definition Model.cpp:56
Tscal add_disc_3d(Tvec center, Tscal central_mass, u32 Npart, Tscal r_in, Tscal r_out, Tscal disc_mass, Tscal p, Tscal H_r_in, Tscal q)
Add a disc distribution.
Definition Model.hpp:285
void dump(std::string fname)
Dump the state of the SPH model to a file.
Definition Model.hpp:927
void init_scheduler(u32 crit_split, u32 crit_merge)
Definition Model.hpp:82
Utility class used to move the objects between patches.
void reatribute_patch_objects(SerialPatchTree< T > &sptree, std::string position_field)
Reattribute objects based on a given position field.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
PatchCoordTransform< T > get_patch_transform() const
Get a PatchCoordTransform object that describes the conversion between patch coordinates and domain coordinates.
Definition SimBox.hpp:285
shambase::DistributedData< PatchData > owned_data
map container for patchdata owned by the current node (layout : id_patch,data)
Class holding the value of numerous constants generated from the following source.
This header file contains utility functions related to exception handling in the code.
MPI string gather / allgather helpers (declarations; implementations in shamalgs/src/collective/gathe...
void gather_str(const std::string &send_vec, std::string &recv_vec)
Gathers a string from all nodes and stores the result in a std::string.
void add_disc(u32 Npart, flt p, flt rho_0, flt m, flt r_in, flt r_out, flt q, Tpred_pusher &&part_pusher)
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the pointer is null.
Definition memory.hpp:110
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for the sph model
void load_shamrock_dump(std::string fname, std::string &metadata_user, ShamrockCtx &ctx)
Load a Shamrock dump file, restore the state of the patches and retrieve the user metadata.
void write_shamrock_dump(std::string fname, std::string metadata_user, PatchScheduler &sched)
Write a Shamrock dump file containing the current state of the patches and user supplied metadata.
The configuration for a sph solver.
Patch object that contain generic patch information.
Definition Patch.hpp:33
u64 id_patch
unique key that identify the patch
Definition Patch.hpp:86