/// Scalar type of the model: `shambase::VecComponent<Tvec>`, i.e. the component type of `Tvec`.
57 using Tscal = shambase::VecComponent<Tvec>;
/// SPH kernel template instantiated on the scalar type.
59 using Kernel = SPHKernel<Tscal>;
/// Solver type for this (`Tvec`, `SPHKernel`) pair.
/// NOTE(review): the alias reuses the name `Solver` on both sides. Inside a class scope this can be
/// rejected by compilers that enforce the "declaration changes meaning" rule (e.g. GCC 13's
/// -Wchanges-meaning diagnostic). Qualifying the right-hand side with its namespace would avoid it —
/// confirm the enclosing namespace before changing.
61 using Solver = Solver<Tvec, SPHKernel>;
83 solver.solver_config.scheduler_conf.split_load_value = crit_split;
84 solver.solver_config.scheduler_conf.merge_load_value = crit_merge;
88 template<std::enable_if_t<dim == 3,
int> = 0>
89 inline Tvec get_box_dim_fcc_3d(Tscal dr,
u32 xcnt,
u32 ycnt,
u32 zcnt) {
90 return generic::setup::generators::get_box_dim(dr, xcnt, ycnt, zcnt);
93 inline void set_cfl_cour(Tscal cfl_cour) {
94 solver.solver_config.cfl_config.cfl_cour = cfl_cour;
96 inline void set_cfl_force(Tscal cfl_force) {
97 solver.solver_config.cfl_config.cfl_force = cfl_force;
99 inline void set_eta_sink(Tscal eta_sink) {
100 solver.solver_config.cfl_config.eta_sink = eta_sink;
102 inline void set_particle_mass(Tscal gpart_mass) {
103 solver.solver_config.gpart_mass = gpart_mass;
106 inline Tscal get_particle_mass() {
return solver.solver_config.gpart_mass; }
108 inline void resize_simulation_box(std::pair<Tvec, Tvec> box) {
109 ctx.set_coord_domain_bound({box.first, box.second});
/// Declared here, defined out of line: builds a `SolverConfig` from a `PhantomDump`.
/// NOTE(review): `bypass_error` presumably relaxes validation of unsupported dump settings —
/// confirm against the definition.
112 SolverConfig gen_config_from_phantom_dump(PhantomDump &phdump,
bool bypass_error);
/// Declared here, defined out of line: initializes the model from a `PhantomDump`.
/// `hpart_fact_load` defaults to 1.0. It is presumably a scale factor applied to the loaded
/// smoothing lengths — confirm.
113 void init_from_phantom_dump(PhantomDump &phdump, Tscal hpart_fact_load = 1.0);
/// Declared here, defined out of line: produces a `PhantomDump` (presumably of the current
/// model state — confirm).
114 PhantomDump make_phantom_dump();
116 void do_vtk_dump(std::string filename,
bool add_patch_world_id) {
117 solver.vtk_do_dump(filename, add_patch_world_id);
120 void set_debug_dump(
bool _do_debug_dump, std::string _debug_dump_filename) {
121 solver.set_debug_dump(_do_debug_dump, _debug_dump_filename);
/// Declared here, defined out of line. Presumably returns the particle count summed over all
/// patches/ranks — confirm.
124 u64 get_total_part_count();
/// Declared here, defined out of line: maps a total mass to a per-particle mass.
/// NOTE(review): the exact rule (e.g. division by the global particle count) is not visible
/// here — confirm.
126 f64 total_mass_to_part_mass(
f64 totmass);
/// Declared here, defined out of line: take a spacing `dr` and a box given as a pair of corners,
/// and return an adjusted box for an FCC / HCP lattice respectively.
/// NOTE(review): presumably the box is snapped to a whole number of lattice cells — confirm.
128 std::pair<Tvec, Tvec> get_ideal_fcc_box(Tscal dr, std::pair<Tvec, Tvec> box);
129 std::pair<Tvec, Tvec> get_ideal_hcp_box(Tscal dr, std::pair<Tvec, Tvec> box);
131 Tscal get_hfact() {
return Kernel::hfactd; }
133 Tscal rho_h(Tscal h) {
134 return shamrock::sph::rho_h(solver.solver_config.gpart_mass, h, Kernel::hfactd);
/// Declared here, defined out of line: particle-lattice generators filling the box `_box` with
/// spacing `dr` (FCC, HCP, and a second HCP variant).
/// NOTE(review): `_box` is presumably a (min corner, max corner) pair, like the pair passed to
/// `resize_simulation_box` — confirm.
137 void add_cube_fcc_3d(Tscal dr, std::pair<Tvec, Tvec> _box);
138 void add_cube_hcp_3d(Tscal dr, std::pair<Tvec, Tvec> _box);
139 void add_cube_hcp_3d_v2(Tscal dr, std::pair<Tvec, Tvec> _box);
141 inline std::unique_ptr<modules::SPHSetup<Tvec, SPHKernel>> get_setup() {
142 return std::make_unique<modules::SPHSetup<Tvec, SPHKernel>>(
143 ctx, solver.solver_config, solver.storage);
165 void add_big_disc_3d(
177 inline void add_sink(Tscal mass, Tvec pos, Tvec velocity, Tscal accretion_radius) {
178 if (solver.storage.sinks.is_empty()) {
179 solver.storage.sinks.set({});
182 shamlog_debug_ln(
"SPH",
"add sink :", mass, pos, velocity, accretion_radius);
184 solver.storage.sinks.get().push_back(
185 {pos, velocity, {}, {}, mass, {}, accretion_radius});
189 inline void set_field_value_lambda(
190 std::string field_name,
const std::function<T(Tvec)> pos_to_val,
const u32 offset) {
204 auto f_nvar = f.get_nvar();
205 if (offset >= f_nvar) {
207 "offset ({}) is out of bounds for field '{}' with nvar {}",
213 auto acc = f.get_buf().copy_to_stdvec();
214 auto acc_xyz =
xyz.get_buf().copy_to_stdvec();
216 u32 obj_cnt = pdat.get_obj_cnt();
217 for (
u32 i = 0; i < obj_cnt; i++) {
218 acc[i * f_nvar + offset] = pos_to_val(acc_xyz[i]);
221 f.get_buf().copy_from_stdvec(acc);
226 inline void overwrite_field_value(
227 std::string field_name,
228 const std::function<std::vector<T>(py::dict)> field_compute,
241 auto f_nvar = f.get_nvar();
242 if (offset >= f_nvar) {
244 "offset ({}) is out of bounds for field '{}' with nvar {}",
250 auto result = field_compute(shamrock::pdat_to_dic(pdat));
252 if (result.size() != f.get_obj_cnt()) {
254 "result.size() != f.get_obj_cnt() ({} != {})",
259 auto acc = f.get_buf().copy_to_stdvec();
261 u32 obj_cnt = pdat.get_obj_cnt();
262 for (
u32 i = 0; i < obj_cnt; i++) {
263 acc[i * f_nvar + offset] = result[i];
266 f.get_buf().copy_from_stdvec(acc);
284 template<std::enable_if_t<dim == 3,
int> = 0>
296 Tscal G = solver.solver_config.get_constant_G();
299 using Config = SolverConfig;
300 using SolverConfigEOS =
typename Config::EOSConfig;
301 using SolverEOS_Adiabatic =
typename SolverConfigEOS::Adiabatic;
302 if (SolverEOS_Adiabatic *eos_config
303 = std::get_if<SolverEOS_Adiabatic>(&solver.solver_config.eos_config.config)) {
305 eos_gamma = eos_config->gamma;
315 auto sigma_profile = [=](Tscal r) {
317 constexpr Tscal sigma_0 = 1;
318 return sigma_0 * sycl::pow(r / r_in, -p);
321 auto cs_law = [=](Tscal r) {
322 return sycl::pow(r / r_in, -q);
325 auto rot_profile = [=](Tscal r) {
326 return sycl::sqrt(G * central_mass / r);
329 Tscal cs_in = H_r_in * rot_profile(r_in);
330 auto cs_profile = [&](Tscal r) {
331 return cs_law(r) * cs_in;
334 std::vector<Out> part_list;
336 generic::setup::generators::add_disc2<Tscal>(
341 return sigma_profile(r);
344 return cs_profile(r);
347 return rot_profile(r);
350 part_list.push_back(out);
353 Tscal part_mass = disc_mass / Npart;
355 using namespace shamrock::patch;
359 std::string log =
"";
366 std::vector<Tvec> vec_pos;
367 std::vector<Tvec> vec_vel;
368 std::vector<Tscal> vec_u;
369 std::vector<Tscal> vec_h;
371 std::vector<Tscal> vec_cs;
373 Tscal G = solver.solver_config.get_constant_G();
375 for (Out o : part_list) {
376 vec_pos.push_back(o.pos + center);
377 vec_vel.push_back(o.velocity);
383 vec_u.push_back(o.cs * o.cs / ( (eos_gamma - 1)));
384 vec_h.push_back(shamrock::sph::h_rho(part_mass, o.rho, Kernel::hfactd));
385 vec_cs.push_back(o.cs);
388 log += shambase::format(
389 "\n patch id={}, add N={} particles", ptch.
id_patch, vec_pos.size());
392 tmp.resize(vec_pos.size());
396 u32 len = vec_pos.size();
398 = tmp.get_field<Tvec>(sched.pdl_old().
get_field_idx<Tvec>(
"xyz"));
399 sycl::buffer<Tvec> buf(vec_pos.data(), len);
400 f.override(buf, len);
404 u32 len = vec_pos.size();
406 = tmp.get_field<Tscal>(sched.pdl_old().
get_field_idx<Tscal>(
"hpart"));
407 sycl::buffer<Tscal> buf(vec_h.data(), len);
408 f.override(buf, len);
412 u32 len = vec_pos.size();
414 = tmp.get_field<Tscal>(sched.pdl_old().
get_field_idx<Tscal>(
"uint"));
415 sycl::buffer<Tscal> buf(vec_u.data(), len);
416 f.override(buf, len);
419 if (solver.solver_config.is_eos_locally_isothermal()) {
420 u32 len = vec_pos.size();
422 = tmp.get_field<Tscal>(sched.pdl_old().
get_field_idx<Tscal>(
"soundspeed"));
423 sycl::buffer<Tscal> buf(vec_cs.data(), len);
424 f.override(buf, len);
428 u32 len = vec_pos.size();
430 = tmp.get_field<Tvec>(sched.pdl_old().
get_field_idx<Tvec>(
"vxyz"));
431 sycl::buffer<Tvec> buf(vec_vel.data(), len);
432 f.override(buf, len);
435 pdat.insert_elements(tmp);
438 std::string log_gathered =
"";
439 shamalgs::collective::gather_str(log, log_gathered);
442 logger::info_ln(
"Model",
"Push particles : ", log_gathered);
446 ctx, solver.solver_config, solver.storage)
447 .update_load_balancing();
452 auto [m, M] = sched.get_box_tranform<Tvec>();
467 sched.check_patchdata_locality_correctness();
473 log += shambase::format(
474 "\n patch id={}, N={} particles", p.id_patch, pdat.get_obj_cnt());
478 shamalgs::collective::gather_str(log, log_gathered);
481 logger::info_ln(
"Model",
"current particle counts : ", log_gathered);
485 template<std::enable_if_t<dim == 3,
int> = 0>
486 inline void add_cube_disc_3d(
499 using SolverConfigEOS =
typename Config::EOSConfig;
500 using SolverEOS_Adiabatic =
typename SolverConfigEOS::Adiabatic;
501 if (SolverEOS_Adiabatic *eos_config
502 = std::get_if<SolverEOS_Adiabatic>(&solver.solver_config.eos_config.config)) {
504 eos_gamma = eos_config->gamma;
510 auto cs = [&](Tscal u) {
511 return sycl::sqrt(eos_gamma * (eos_gamma - 1) * u);
514 auto U = [&](Tscal cs) {
515 return cs * cs / (eos_gamma * (eos_gamma - 1));
518 using namespace shamrock::patch;
522 std::string log =
"";
529 std::vector<Tvec> vec_acc;
530 std::vector<Tvec> vec_vel;
531 std::vector<Tscal> vec_u;
533 Tscal G = solver.solver_config.get_constant_G();
536 Npart, p, rho_0, m, r_in, r_out, q, [&](Tvec r, Tscal h) {
537 vec_acc.push_back(r + center);
539 Tscal R = sycl::length(r);
541 Tscal V = sycl::sqrt(G * cmass / R);
543 Tvec etheta = {-r.z(), 0, r.x()};
544 etheta /= sycl::length(etheta);
546 vec_vel.push_back(V * etheta);
549 Tscal cs = cs0 * sycl::pow(R, -q);
551 vec_u.push_back(U(cs));
554 log += shambase::format(
555 "\n patch id={}, add N={} particles", ptch.
id_patch, vec_acc.size());
558 tmp.resize(vec_acc.size());
562 u32 len = vec_acc.size();
564 = tmp.get_field<Tvec>(sched.pdl_old().
get_field_idx<Tvec>(
"xyz"));
565 sycl::buffer<Tvec> buf(vec_acc.data(), len);
566 f.override(buf, len);
571 = tmp.get_field<Tscal>(sched.pdl_old().
get_field_idx<Tscal>(
"hpart"));
576 u32 len = vec_acc.size();
578 = tmp.get_field<Tscal>(sched.pdl_old().
get_field_idx<Tscal>(
"uint"));
579 sycl::buffer<Tscal> buf(vec_u.data(), len);
580 f.override(buf, len);
584 u32 len = vec_acc.size();
586 = tmp.get_field<Tvec>(sched.pdl_old().
get_field_idx<Tvec>(
"vxyz"));
587 sycl::buffer<Tvec> buf(vec_vel.data(), len);
588 f.override(buf, len);
591 pdat.insert_elements(tmp);
594 std::string log_gathered =
"";
598 logger::info_ln(
"Model",
"Push particles : ", log_gathered);
601 modules::ComputeLoadBalanceValue<Tvec, SPHKernel>(
602 ctx, solver.solver_config, solver.storage)
603 .update_load_balancing();
608 auto [m, M] = sched.get_box_tranform<Tvec>();
620 reatrib.reatribute_patch_objects(sptree,
"xyz");
623 sched.check_patchdata_locality_correctness();
629 log += shambase::format(
630 "\n patch id={}, N={} particles", p.id_patch, pdat.get_obj_cnt());
637 logger::info_ln(
"Model",
"current particle counts : ", log_gathered);
/// Declared here, defined out of line: takes a position-to-position mapping `map`.
/// NOTE(review): presumably applied to every particle's `xyz` — confirm in the definition.
640 void remap_positions(std::function<Tvec(Tvec)> map);
643 std::vector<Tvec> &part_pos_insert,
644 std::vector<Tscal> &part_hpart_insert,
645 std::vector<Tscal> &part_u_insert);
/// Declared here, defined out of line: MHD particle insertion.
/// Takes per-particle positions, smoothing lengths, `part_u_insert` (presumably internal energy,
/// cf. the "uint" field used elsewhere in this file), `B/rho`, and `psi/ch`.
/// NOTE(review): the vectors are passed by non-const reference. Whether they are modified, and
/// whether they must all have equal length, is not visible here — confirm.
647 void push_particle_mhd(
648 std::vector<Tvec> &part_pos_insert,
649 std::vector<Tscal> &part_hpart_insert,
650 std::vector<Tscal> &part_u_insert,
651 std::vector<Tvec> &part_B_on_rho_insert,
652 std::vector<Tscal> &part_psi_on_ch_insert);
655 inline void set_value_in_a_box(
656 std::string field_name, T val, std::pair<Tvec, Tvec> box,
u32 ivar) {
662 = pdat.template get_field<Tvec>(sched.pdl_old().
get_field_idx<Tvec>(
"xyz"));
665 = pdat.template get_field<T>(sched.pdl_old().
get_field_idx<T>(field_name));
667 if (ivar >= f.get_nvar()) {
669 "You are trying to set value in a box for field ({}) with "
670 "ivar ({}) >= f.get_nvar ({})",
676 u32 nvar = f.get_nvar();
679 auto acc = f.get_buf().template mirror_to<sham::host>();
680 auto acc_xyz =
xyz.get_buf().template mirror_to<sham::host>();
682 for (
u32 i = 0; i < f.get_obj_cnt(); i++) {
685 if (BBAA::is_coord_in_range(r, std::get<0>(box), std::get<1>(box))) {
686 acc[i * nvar + ivar] = val;
694 inline void set_value_in_sphere(std::string field_name, T val, Tvec center, Tscal radius) {
700 = pdat.template get_field<Tvec>(sched.pdl_old().
get_field_idx<Tvec>(
"xyz"));
703 = pdat.template get_field<T>(sched.pdl_old().
get_field_idx<T>(field_name));
705 if (f.get_nvar() != 1) {
709 Tscal r2 = radius * radius;
711 auto acc = f.get_buf().template mirror_to<sham::host>();
712 auto acc_xyz =
xyz.get_buf().template mirror_to<sham::host>();
714 for (
u32 i = 0; i < f.get_obj_cnt(); i++) {
715 Tvec dr = acc_xyz[i] - center;
717 if (sycl::dot(dr, dr) < r2) {
726 inline void add_kernel_value(std::string field_name, T val, Tvec center, Tscal h_ker) {
732 = pdat.template get_field<Tvec>(sched.pdl_old().
get_field_idx<Tvec>(
"xyz"));
735 = pdat.template get_field<T>(sched.pdl_old().
get_field_idx<T>(field_name));
737 if (f.get_nvar() != 1) {
742 auto acc = f.get_buf().template mirror_to<sham::host>();
743 auto acc_xyz =
xyz.get_buf().template mirror_to<sham::host>();
745 for (
u32 i = 0; i < f.get_obj_cnt(); i++) {
746 Tvec dr = acc_xyz[i] - center;
748 Tscal r = sycl::length(dr);
750 acc[i] += val * Kernel::W_3d(r, h_ker);
757 inline T get_sum(std::string name) {
765 = pdat.template get_field<T>(sched.pdl_old().
get_field_idx<T>(name));
767 sum +=
xyz.compute_sum();
770 return shamalgs::collective::allreduce_sum(sum);
/// Declared here, defined out of line. Presumably returns the position of the particle nearest
/// to `pos` (possibly across ranks) — confirm.
773 Tvec get_closest_part_to(Tvec pos);
775 inline void apply_momentum_offset(Tvec offset) {
784 sched.for_each_patchdata_nonempty(
786 tot_mass += solver.solver_config.gpart_mass * pdat.get_obj_cnt();
789 tot_mass = shamalgs::collective::allreduce_sum(tot_mass);
792 if (!solver.storage.sinks.is_empty()) {
793 for (
auto &s : solver.storage.sinks.get()) {
799 Tvec offset_vel = (tot_mass > 0) ? (offset / tot_mass)
803 if (!solver.storage.sinks.is_empty()) {
804 for (
auto &s : solver.storage.sinks.get()) {
805 s.velocity += offset_vel;
810 sched.for_each_patchdata_nonempty(
813 vxyz.apply_offset(offset_vel);
817 inline void apply_position_offset(Tvec offset) {
824 if (!solver.storage.sinks.is_empty()) {
825 for (
auto &s : solver.storage.sinks.get()) {
831 sched.for_each_patchdata_nonempty(
834 xyz.apply_offset(offset);
846 inline void set_solver_config(
typename Solver::Config cfg) {
849 "Cannot change solver config after scheduler is initialized");
852 solver.solver_config = cfg;
855 inline f64 solver_logs_last_rate() {
return solver.solve_logs.get_last_rate(); }
856 inline u64 solver_logs_last_obj_count() {
return solver.solve_logs.get_last_obj_count(); }
857 inline f64 solver_logs_cumulated_step_time() {
858 return solver.solve_logs.get_cumulated_step_time();
860 inline void solver_logs_reset_cumulated_step_time() {
861 solver.solve_logs.reset_cumulated_step_time();
863 inline u64 solver_logs_step_count() {
return solver.solve_logs.get_step_count(); }
864 inline void solver_logs_reset_step_count() { solver.solve_logs.reset_step_count(); }
866 inline void change_htolerances(Tscal in_coarse, Tscal in_fine) {
867 if (in_coarse < in_fine) {
869 "in_coarse ({}) must be greater than in_fine ({})", in_coarse, in_fine));
871 solver.solver_config.htol_up_coarse_cycle = in_coarse;
872 solver.solver_config.htol_up_fine_cycle = in_fine;
890 logger::info_ln(
"SPH",
"Loading state from dump", fname);
894 std::string metadata_user{};
898 nlohmann::json j = nlohmann::json::parse(metadata_user);
900 j.at(
"solver_config").get_to(solver.solver_config);
902 if (!j.at(
"sinks").is_null()) {
903 std::vector<SinkParticle<Tvec>> out;
904 j.at(
"sinks").get_to(out);
905 solver.storage.sinks.set(std::move(out));
908 solver.init_ghost_layout();
910 solver.init_solver_graph();
913 shamlog_debug_ln(
"Sys",
"build local scheduler tables");
927 inline void dump(std::string fname) {
929 logger::info_ln(
"SPH",
"Dumping state to", fname);
932 solver.update_sync_load_values();
934 nlohmann::json metadata;
935 metadata[
"solver_config"] = solver.solver_config;
937 if (solver.storage.sinks.is_empty()) {
938 metadata[
"sinks"] = nlohmann::json{};
940 metadata[
"sinks"] = solver.storage.sinks.get();
/// Declared here, defined out of line: advances one step from time `t_curr` with requested step
/// `dt_input`.
/// NOTE(review): the returned `f64` is presumably the next timestep suggested by the CFL
/// condition — confirm in the definition.
953 f64 evolve_once_time_expl(
f64 t_curr,
f64 dt_input);
957 inline void evolve_once() {
958 solver.evolve_once();
959 solver.print_timestep_logs();
962 inline bool evolve_until(Tscal target_time,
i32 niter_max) {
963 return solver.evolve_until(target_time, niter_max);
967 void add_pdat_to_phantom_block(
970 template<
class Tscal>
971 inline void warp_disc(
972 std::vector<Tvec> &pos,
973 std::vector<Tvec> &vel,
978 Tvec k = Tvec(-std::sin(posangle), std::cos(posangle), 0.);
981 u32 len = pos.size();
984 Tscal incl_rad = incl * shambase::constants::pi<Tscal> / 180.;
986 for (
i32 i = 0; i < len; i++) {
988 Tscal R = sycl::sqrt(sycl::dot(R_vec, R_vec));
989 if (R < Rwarp - Hwarp) {
991 }
else if (R < Rwarp + 3. * Hwarp && R > Rwarp - Hwarp) {
995 + sycl::sin(shambase::constants::pi<Tscal> / (2. * Hwarp) * (R - Rwarp)))
996 * sycl::sin(incl_rad));
997 psi = shambase::constants::pi<Tscal>
998 * Rwarp / (4. * Hwarp) * sycl::sin(incl_rad)
999 / sycl::sqrt(1. - (0.5 * sycl::pow(sycl::sin(incl_rad), 2)));
1000 Tscal psimax = sycl::max(psimax, psi);
1001 Tscal x = pos[i].x();
1002 Tscal y = pos[i].y();
1003 Tscal z = pos[i].z();
1009 Tvec kk = Tvec(0., 0., 1.);
1010 Tvec w = sycl::cross(kk, pos[i]);
1012 pos[i] = pos[i] * sycl::cos(inc) + w * sycl::sin(inc)
1013 + kk * sycl::dot(kk, pos[i]) * (1. - sycl::cos(inc));
1021 inline void rotate_vector(Tvec &u, Tvec &v, Tscal theta) {
1023 Tvec vunit = v / sycl::sqrt(sycl::dot(v, v));
1024 Tvec w = sycl::cross(vunit, u);
1026 u = u * sycl::cos(theta) + w * sycl::sin(theta)
1027 + vunit * sycl::dot(vunit, u) * (1. - sycl::cos(theta));