Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
UpdateDerivs.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
17
19#include "shambase/memory.hpp"
22#include "shambackends/math.hpp"
23#include "shamcomm/logs.hpp"
40#include "shamphys/mhd.hpp"
46#include <memory>
47#include <vector>
48
/**
 * @brief Computes the time derivatives for the current step by dispatching on the
 * solver configuration, then optionally adds the dust S_j contribution.
 *
 * The artificial-viscosity alternatives of solver_config.artif_viscosity (Constant,
 * VaryingMM97, VaryingCD10, ConstantDisc) are tested first, then the alternatives of
 * solver_config.mhd_config (IdealMHD, NonIdealMHD, NoneMHD), then the `None` AV
 * alternative.
 *
 * @param dt_hydro Hydro timestep; in this function it is only forwarded to
 *                 update_derivs_dust_monofluid_tvi_Sj.
 */
49template<class Tvec, template<class> class SPHKernel>
50void shammodels::sph::modules::UpdateDerivs<Tvec, SPHKernel>::update_derivs(Tscal dt_hydro) {
51
52 Cfg_AV cfg_av = solver_config.artif_viscosity;
53 Cfg_MHD cfg_mhd = solver_config.mhd_config;
54 DustConfig cfg_dust = solver_config.dust_config;
55
// NOTE(review): AV and MHD alternatives share a single if/else-if chain, so
// update_derivs_MHD is only reached when cfg_av.config holds none of
// Constant / VaryingMM97 / VaryingCD10 / ConstantDisc. If cfg_mhd.config can only
// hold IdealMHD, NonIdealMHD or NoneMHD, the `None` AV branch and the final `else`
// below can never be taken -- confirm against the Cfg_AV / Cfg_MHD variant definitions.
56 if (Constant *v = std::get_if<Constant>(&cfg_av.config)) {
57 update_derivs_constantAV(*v);
58 } else if (VaryingMM97 *v = std::get_if<VaryingMM97>(&cfg_av.config)) {
59 update_derivs_mm97(*v);
60 } else if (VaryingCD10 *v = std::get_if<VaryingCD10>(&cfg_av.config)) {
61 update_derivs_cd10(*v);
62 } else if (ConstantDisc *v = std::get_if<ConstantDisc>(&cfg_av.config)) {
63 update_derivs_disc_visco(*v);
64 } else if (IdealMHD *v = std::get_if<IdealMHD>(&cfg_mhd.config)) {
65 update_derivs_MHD(*v);
66 } else if (NonIdealMHD *v = std::get_if<NonIdealMHD>(&cfg_mhd.config)) {
68 } else if (NoneMHD *v = std::get_if<NoneMHD>(&cfg_mhd.config)) {
70 } else if (None *v = std::get_if<None>(&cfg_av.config)) {
72 } else {
74 }
75
// The dust S_j update runs regardless of which branch was taken above.
76 if (cfg_dust.has_s_j_field()) {
77 // we can do it separately because the backreaction is done only through the pressure
78 update_derivs_dust_monofluid_tvi_Sj(cfg_dust, dt_hydro);
79 }
80}
81
82template<class Tvec, template<class> class SPHKernel>
83void shammodels::sph::modules::UpdateDerivs<Tvec, SPHKernel>::update_derivs_noAV(None cfg) {}
84
/**
 * @brief Computes axyz and duint for all local particles using a constant
 * artificial viscosity (same alpha for every particle).
 *
 * For each non-empty patch, one device kernel visits every local particle a,
 * loops over its cached neighbours b, evaluates the pair quantities (densities
 * from h, kernel gradients, signal velocities, AV terms qa_ab / qb_ab) and
 * accumulates them through add_to_derivs_sph_artif_visco_cond. The results
 * overwrite axyz[id_a] and duint[id_a]; they are assigned, not added.
 *
 * @param cfg Constant-AV parameters; alpha_u, alpha_AV and beta_AV are read.
 */
85template<class Tvec, template<class> class SPHKernel>
86void shammodels::sph::modules::UpdateDerivs<Tvec, SPHKernel>::update_derivs_constantAV(
87 Constant cfg) {
88 StackEntry stack_loc{};
89
90 using namespace shamrock;
91 using namespace shamrock::patch;
92
93 PatchDataLayerLayout &pdl = scheduler().pdl_old();
94
// NOTE(review): of the pdl indices below only iaxyz and iduint are used; ixyz,
// ivxyz, iuint and ihpart are unused (neighbour data is read through the
// ghost-layout indices further down).
95 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
96 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
97 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
98 const u32 iuint = pdl.get_field_idx<Tscal>("uint");
99 const u32 iduint = pdl.get_field_idx<Tscal>("duint");
100 const u32 ihpart = pdl.get_field_idx<Tscal>("hpart");
101
103 = shambase::get_check_ref(storage.ghost_layout.get());
104 u32 ihpart_interf = ghost_layout.get_field_idx<Tscal>("hpart");
105 u32 iuint_interf = ghost_layout.get_field_idx<Tscal>("uint");
106 u32 ivxyz_interf = ghost_layout.get_field_idx<Tvec>("vxyz");
107 u32 iomega_interf = ghost_layout.get_field_idx<Tscal>("omega");
108
109 auto &merged_xyzh = storage.merged_xyzh.get();
111 shambase::DistributedData<PatchDataLayer> &mpdats = storage.merged_patchdata_ghost.get();
112
113 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
114 PatchDataLayer &mpdat = mpdats.get(cur_p.id_patch);
115
117 = merged_xyzh.get(cur_p.id_patch).template get_field_buf_ref<Tvec>(0);
// Outputs (axyz, duint) live in the local patch data. Inputs come from the merged
// (ghost-including, by name) data, which is what the neighbour index id_b
// addresses inside the kernel.
118 sham::DeviceBuffer<Tvec> &buf_axyz = pdat.get_field_buf_ref<Tvec>(iaxyz);
119 sham::DeviceBuffer<Tscal> &buf_duint = pdat.get_field_buf_ref<Tscal>(iduint);
120 sham::DeviceBuffer<Tvec> &buf_vxyz = mpdat.get_field_buf_ref<Tvec>(ivxyz_interf);
121 sham::DeviceBuffer<Tscal> &buf_hpart = mpdat.get_field_buf_ref<Tscal>(ihpart_interf);
122 sham::DeviceBuffer<Tscal> &buf_omega = mpdat.get_field_buf_ref<Tscal>(iomega_interf);
123 sham::DeviceBuffer<Tscal> &buf_uint = mpdat.get_field_buf_ref<Tscal>(iuint_interf);
124 sham::DeviceBuffer<Tscal> &buf_pressure
125 = shambase::get_check_ref(storage.pressure).get_field(cur_p.id_patch).get_buf();
127 = shambase::get_check_ref(storage.soundspeed).get_field(cur_p.id_patch).get_buf();
128
// NOTE(review): range_npart is unused; the kernel below is sized with pdat.get_obj_cnt().
129 sycl::range range_npart{pdat.get_obj_cnt()};
130
131 tree::ObjectCache &pcache
132 = shambase::get_check_ref(storage.neigh_cache).get_cache(cur_p.id_patch);
133
135
136 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
137 sham::EventList depends_list;
138
// Acquire device pointers for every buffer the kernel touches; depends_list
// collects what q.submit below has to wait on.
139 auto xyz = buf_xyz.get_read_access(depends_list);
140 auto axyz = buf_axyz.get_write_access(depends_list);
141 auto du = buf_duint.get_write_access(depends_list);
142 auto vxyz = buf_vxyz.get_read_access(depends_list);
143 auto hpart = buf_hpart.get_read_access(depends_list);
144 auto omega = buf_omega.get_read_access(depends_list);
145 auto u = buf_uint.get_read_access(depends_list); // TODO rename to uint
146 auto pressure = buf_pressure.get_read_access(depends_list);
147 auto cs = buf_cs.get_read_access(depends_list);
148 auto ploop_ptrs = pcache.get_read_access(depends_list);
149
150 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
151 const Tscal pmass = solver_config.gpart_mass;
152 const Tscal alpha_u = cfg.alpha_u;
153 const Tscal alpha_AV = cfg.alpha_AV;
154 const Tscal beta_AV = cfg.beta_AV;
155
156 shamlog_debug_sycl_ln("deriv kernel", "alpha_u :", alpha_u);
157 shamlog_debug_sycl_ln("deriv kernel", "alpha_AV :", alpha_AV);
158 shamlog_debug_sycl_ln("deriv kernel", "beta_AV :", beta_AV);
159
160 // tree::ObjectIterator particle_looper(tree,cgh);
161
162 // tree::LeafCacheObjectIterator
163 // particle_looper(tree,*xyz_cell_id,leaf_cache,cgh);
164
165 tree::ObjectCacheIterator particle_looper(ploop_ptrs);
166
167 // sycl::accessor hmax_tree{tree_field_hmax, cgh, sycl::read_only};
168
169 // sycl::stream out {4096,1024,cgh};
170
// Squared kernel support radius (in units of h), used as the neighbour cut-off.
171 constexpr Tscal Rker2 = Kernel::Rkern * Kernel::Rkern;
172
173 shambase::parallel_for(cgh, pdat.get_obj_cnt(), "compute force cte AV", [=](u64 gid) {
174 u32 id_a = (u32) gid;
175
176 using namespace shamrock::sph;
177
// NOTE(review): sum_axyz and sum_du_a are never used; force_pressure and
// tmpdU_pressure below are the actual accumulators.
178 Tvec sum_axyz = {0, 0, 0};
179 Tscal sum_du_a = 0;
180
181 Tscal h_a = hpart[id_a];
182 Tvec xyz_a = xyz[id_a];
183 Tvec vxyz_a = vxyz[id_a];
184 Tscal P_a = pressure[id_a];
185 Tscal omega_a = omega[id_a];
186 const Tscal u_a = u[id_a];
187
188 Tscal rho_a = rho_h(pmass, h_a, Kernel::hfactd);
189 Tscal rho_a_sq = rho_a * rho_a;
// NOTE(review): `1.` is a double literal, so for a single-precision Tscal this
// division is evaluated in double precision inside the device kernel;
// `Tscal{1} / rho_a` would stay in Tscal (matters on devices without fp64).
190 Tscal rho_a_inv = 1. / rho_a;
191
192 // f32 P_a = cs * cs * rho_a;
193
194 Tscal omega_a_rho_a_inv = 1 / (omega_a * rho_a);
195
196 Tscal cs_a = cs[id_a];
197
198 Tvec force_pressure{0, 0, 0};
199 Tscal tmpdU_pressure = 0;
200
// Loop over cached neighbours of a. The pre-existing "compute only omega_a" comment
// inside the loop looks stale: the loop accumulates force / du contributions.
201 particle_looper.for_each_object(id_a, [&](u32 id_b) {
202 // compute only omega_a
203 Tvec dr = xyz_a - xyz[id_b];
204 Tscal rab2 = sycl::dot(dr, dr);
205 Tscal h_b = hpart[id_b];
206
// Skip pairs lying outside the kernel support of both particles.
207 if (rab2 > h_a * h_a * Rker2 && rab2 > h_b * h_b * Rker2) {
208 return;
209 }
210
211 Tscal rab = sycl::sqrt(rab2);
212 Tvec vxyz_b = vxyz[id_b];
213 const Tscal u_b = u[id_b];
214
215 Tscal rho_b = rho_h(pmass, h_b, Kernel::hfactd);
216 Tscal P_b = pressure[id_b];
217 // f32 P_b = cs * cs * rho_b;
218 Tscal omega_b = omega[id_b];
219 Tscal cs_b = cs[id_b];
220
// Constant AV: both particles of the pair use the same alpha.
221 const Tscal alpha_a = alpha_AV;
222 const Tscal alpha_b = alpha_AV;
223
224 Tscal Fab_a = Kernel::dW_3d(rab, h_a);
225 Tscal Fab_b = Kernel::dW_3d(rab, h_b);
226
227 Tvec v_ab = vxyz_a - vxyz_b;
228
// inv_sat_positive presumably saturates for rab == 0 (self or coincident pair)
// instead of dividing by zero -- TODO confirm in shambackends/math.hpp.
229 Tvec r_ab_unit = dr * sham::inv_sat_positive(rab);
230
231 // f32 P_b = cs * cs * rho_b;
232 Tscal v_ab_r_ab = sycl::dot(v_ab, r_ab_unit);
233 Tscal abs_v_ab_r_ab = sycl::fabs(v_ab_r_ab);
234
235 Tscal vsig_a = alpha_a * cs_a + beta_AV * abs_v_ab_r_ab;
236 Tscal vsig_b = alpha_b * cs_b + beta_AV * abs_v_ab_r_ab;
237
238 Tscal vsig_u = shamrock::sph::vsig_u(P_a, P_b, rho_a, rho_b);
239
240 Tscal qa_ab = shamrock::sph::q_av(rho_a, vsig_a, v_ab_r_ab);
241 Tscal qb_ab = shamrock::sph::q_av(rho_b, vsig_b, v_ab_r_ab);
242
243 add_to_derivs_sph_artif_visco_cond(
244 pmass,
245 rho_a_sq,
246 omega_a_rho_a_inv,
247 rho_a_inv,
248 rho_b,
249 omega_a,
250 omega_b,
251 Fab_a,
252 Fab_b,
253 u_a,
254 u_b,
255 P_a,
256 P_b,
257 alpha_u,
258 v_ab,
259 r_ab_unit,
260 vsig_u,
261 qa_ab,
262 qb_ab,
263 force_pressure,
264 tmpdU_pressure);
265 });
// Overwrite (not accumulate into) the outputs of particle a.
266 axyz[id_a] = force_pressure;
267 du[id_a] = tmpdU_pressure;
268 });
269 });
270
// Hand the kernel event back to every buffer accessed above and to the neighbour cache.
271 buf_xyz.complete_event_state(e);
272 buf_axyz.complete_event_state(e);
273 buf_duint.complete_event_state(e);
274 buf_vxyz.complete_event_state(e);
275 buf_hpart.complete_event_state(e);
276 buf_omega.complete_event_state(e);
277 buf_uint.complete_event_state(e);
278 buf_pressure.complete_event_state(e);
279 buf_cs.complete_event_state(e);
280
281 sham::EventList resulting_events;
282 resulting_events.add_event(e);
283 pcache.complete_event_state(resulting_events);
284 });
285}
/**
 * @brief Computes axyz and duint with a per-particle (varying) artificial
 * viscosity, for the VaryingMM97 configuration.
 *
 * Rather than launching a kernel directly, this function wraps its inputs as
 * solver-graph edges and evaluates a NodeUpdateDerivsVaryingAlphaAV node:
 * - uint, vxyz, hpart and omega come from the merged ghost patch data;
 * - alpha_av comes from storage.alpha_av_ghost;
 * - alpha_u and beta_AV come from cfg;
 * - axyz, duint and gpart_mass are taken from storage.solver_graph.
 *
 * NOTE(review): the body is identical to update_derivs_cd10 apart from the cfg
 * type. Only cfg.alpha_u and cfg.beta_AV are read here.
 *
 * @param cfg Varying-AV parameters (only alpha_u and beta_AV are used here).
 */
286template<class Tvec, template<class> class SPHKernel>
287void shammodels::sph::modules::UpdateDerivs<Tvec, SPHKernel>::update_derivs_mm97(VaryingMM97 cfg) {
288 StackEntry stack_loc{};
289
290 using namespace shamrock;
291 using namespace shamrock::patch;
292
293 PatchDataLayerLayout &pdl = scheduler().pdl_old();
294
// NOTE(review): none of the pdl field indices below are used in this function.
295 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
296 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
297 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
298 const u32 iuint = pdl.get_field_idx<Tscal>("uint");
299 const u32 iduint = pdl.get_field_idx<Tscal>("duint");
300 const u32 ihpart = pdl.get_field_idx<Tscal>("hpart");
301
303 = shambase::get_check_ref(storage.ghost_layout.get());
304 u32 ihpart_interf = ghost_layout.get_field_idx<Tscal>("hpart");
305 u32 iuint_interf = ghost_layout.get_field_idx<Tscal>("uint");
306 u32 ivxyz_interf = ghost_layout.get_field_idx<Tvec>("vxyz");
307 u32 iomega_interf = ghost_layout.get_field_idx<Tscal>("omega");
308
// NOTE(review): merged_xyzh is not used in this function.
309 auto &merged_xyzh = storage.merged_xyzh.get();
311 shambase::DistributedData<PatchDataLayer> &mpdats = storage.merged_patchdata_ghost.get();
312
// Edges already held in storage are reused as-is.
313 auto &part_counts = storage.part_counts;
314 auto &part_counts_with_ghost = storage.part_counts_with_ghost;
315 auto &xyz_refs = storage.positions_with_ghosts;
316 auto &pressure_field = storage.pressure;
317 auto &soundspeed_field = storage.soundspeed;
318
// Wrap ghost-merged fields as FieldRefs edges, one reference per entry of mpdats.
319 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> uint_refs
320 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("uint", "u");
321 {
322 shambase::get_check_ref(uint_refs).set_refs(
323 mpdats.map<std::reference_wrapper<PatchDataField<Tscal>>>(
324 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
325 return std::ref(mpdat.get_field<Tscal>(iuint_interf));
326 }));
327 }
328
329 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tvec>> vxyz_refs
330 = std::make_shared<shamrock::solvergraph::FieldRefs<Tvec>>("vxyz", "v");
331 {
332 shambase::get_check_ref(vxyz_refs).set_refs(
333 mpdats.map<std::reference_wrapper<PatchDataField<Tvec>>>(
334 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
335 return std::ref(mpdat.get_field<Tvec>(ivxyz_interf));
336 }));
337 }
338
339 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> hpart_refs
340 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("hpart", "h");
341 { // if was just reset before this call
342 shambase::get_check_ref(hpart_refs)
343 .set_refs(mpdats.map<std::reference_wrapper<PatchDataField<Tscal>>>(
344 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
345 return std::ref(mpdat.get_field<Tscal>(ihpart_interf));
346 }));
347 }
348
349 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> omega_refs
350 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("omega", "omega");
351 {
352 shambase::get_check_ref(omega_refs)
353 .set_refs(mpdats.map<std::reference_wrapper<PatchDataField<Tscal>>>(
354 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
355 return std::ref(mpdat.get_field<Tscal>(iomega_interf));
356 }));
357 }
358
// NOTE(review): alpha_av refs are collected only for non-empty patches, whereas the
// refs above map over every entry of mpdats; the two key sets may differ -- confirm
// this is intended.
359 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> alpha_av_refs
360 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("alpha_av", "alpha_av");
361 {
363 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
364 refs.add_obj(
365 cur_p.id_patch, std::ref(storage.alpha_av_ghost.get().get(cur_p.id_patch)));
366 });
367 shambase::get_check_ref(alpha_av_refs).set_refs(refs);
368 }
369
370 shamrock::solvergraph::SolverGraph &solver_graph = storage.solver_graph;
371
// Output edges and the particle mass are looked up in the shared solver graph.
372 auto axyz_refs = solver_graph.get_edge_ptr<shamrock::solvergraph::FieldRefs<Tvec>>("axyz");
373 auto duint_refs = solver_graph.get_edge_ptr<shamrock::solvergraph::FieldRefs<Tscal>>("duint");
374 auto gpart_mass
375 = solver_graph.get_edge_ptr<shamrock::solvergraph::ScalarEdge<Tscal>>("gpart_mass");
376
// Scalar parameters from cfg are wrapped as ScalarEdge values for the node.
377 std::shared_ptr<shamrock::solvergraph::ScalarEdge<Tscal>> alpha_u
378 = std::make_shared<shamrock::solvergraph::ScalarEdge<Tscal>>("alpha_u", "alpha_u");
379 {
380 shambase::get_check_ref(alpha_u).value = cfg.alpha_u;
381 }
382 std::shared_ptr<shamrock::solvergraph::ScalarEdge<Tscal>> beta_AV
383 = std::make_shared<shamrock::solvergraph::ScalarEdge<Tscal>>("beta_AV", "beta_AV");
384 {
385 shambase::get_check_ref(beta_AV).value = cfg.beta_AV;
386 }
387
// The node does the actual work in evaluate(); axyz_refs and duint_refs are passed
// last, presumably as its outputs.
388 std::shared_ptr<NodeUpdateDerivsVaryingAlphaAV<Tvec, SPHKernel>> node
389 = std::make_shared<NodeUpdateDerivsVaryingAlphaAV<Tvec, SPHKernel>>();
390 {
391 node->set_edges(
392 gpart_mass,
393 alpha_u,
394 beta_AV,
395 part_counts,
396 part_counts_with_ghost,
397 xyz_refs,
398 hpart_refs,
399 vxyz_refs,
400 uint_refs,
401 omega_refs,
402 pressure_field,
403 soundspeed_field,
404 alpha_av_refs,
405 storage.neigh_cache,
406 axyz_refs,
407 duint_refs);
408 }
409 node->evaluate();
410}
/**
 * @brief Computes axyz and duint with a per-particle (varying) artificial
 * viscosity, for the VaryingCD10 configuration.
 *
 * Rather than launching a kernel directly, this function wraps its inputs as
 * solver-graph edges and evaluates a NodeUpdateDerivsVaryingAlphaAV node:
 * - uint, vxyz, hpart and omega come from the merged ghost patch data;
 * - alpha_av comes from storage.alpha_av_ghost;
 * - alpha_u and beta_AV come from cfg;
 * - axyz, duint and gpart_mass are taken from storage.solver_graph.
 *
 * NOTE(review): the body is identical to update_derivs_mm97 apart from the cfg
 * type. Only cfg.alpha_u and cfg.beta_AV are read here; any other CD10
 * parameters are presumably consumed elsewhere (e.g. where alpha_av is
 * updated) -- confirm.
 *
 * @param cfg Varying-AV parameters (only alpha_u and beta_AV are used here).
 */
411template<class Tvec, template<class> class SPHKernel>
412void shammodels::sph::modules::UpdateDerivs<Tvec, SPHKernel>::update_derivs_cd10(VaryingCD10 cfg) {
413 StackEntry stack_loc{};
414
415 using namespace shamrock;
416 using namespace shamrock::patch;
417
418 PatchDataLayerLayout &pdl = scheduler().pdl_old();
419
// NOTE(review): none of the pdl field indices below are used in this function.
420 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
421 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
422 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
423 const u32 iuint = pdl.get_field_idx<Tscal>("uint");
424 const u32 iduint = pdl.get_field_idx<Tscal>("duint");
425 const u32 ihpart = pdl.get_field_idx<Tscal>("hpart");
426
428 = shambase::get_check_ref(storage.ghost_layout.get());
429 u32 ihpart_interf = ghost_layout.get_field_idx<Tscal>("hpart");
430 u32 iuint_interf = ghost_layout.get_field_idx<Tscal>("uint");
431 u32 ivxyz_interf = ghost_layout.get_field_idx<Tvec>("vxyz");
432 u32 iomega_interf = ghost_layout.get_field_idx<Tscal>("omega");
433
// NOTE(review): merged_xyzh is not used in this function.
434 auto &merged_xyzh = storage.merged_xyzh.get();
436 shambase::DistributedData<PatchDataLayer> &mpdats = storage.merged_patchdata_ghost.get();
437
// Edges already held in storage are reused as-is.
438 auto &part_counts = storage.part_counts;
439 auto &part_counts_with_ghost = storage.part_counts_with_ghost;
440 auto &xyz_refs = storage.positions_with_ghosts;
441 auto &pressure_field = storage.pressure;
442 auto &soundspeed_field = storage.soundspeed;
443
// Wrap ghost-merged fields as FieldRefs edges, one reference per entry of mpdats.
444 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> uint_refs
445 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("uint", "u");
446 {
447 shambase::get_check_ref(uint_refs).set_refs(
448 mpdats.map<std::reference_wrapper<PatchDataField<Tscal>>>(
449 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
450 return std::ref(mpdat.get_field<Tscal>(iuint_interf));
451 }));
452 }
453
454 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tvec>> vxyz_refs
455 = std::make_shared<shamrock::solvergraph::FieldRefs<Tvec>>("vxyz", "v");
456 {
457 shambase::get_check_ref(vxyz_refs).set_refs(
458 mpdats.map<std::reference_wrapper<PatchDataField<Tvec>>>(
459 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
460 return std::ref(mpdat.get_field<Tvec>(ivxyz_interf));
461 }));
462 }
463
464 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> hpart_refs
465 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("hpart", "h");
466 { // if was just reset before this call
467 shambase::get_check_ref(hpart_refs)
468 .set_refs(mpdats.map<std::reference_wrapper<PatchDataField<Tscal>>>(
469 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
470 return std::ref(mpdat.get_field<Tscal>(ihpart_interf));
471 }));
472 }
473
474 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> omega_refs
475 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("omega", "omega");
476 {
477 shambase::get_check_ref(omega_refs)
478 .set_refs(mpdats.map<std::reference_wrapper<PatchDataField<Tscal>>>(
479 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
480 return std::ref(mpdat.get_field<Tscal>(iomega_interf));
481 }));
482 }
483
// NOTE(review): alpha_av refs are collected only for non-empty patches, whereas the
// refs above map over every entry of mpdats; the two key sets may differ -- confirm
// this is intended.
484 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> alpha_av_refs
485 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("alpha_av", "alpha_av");
486 {
488 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
489 refs.add_obj(
490 cur_p.id_patch, std::ref(storage.alpha_av_ghost.get().get(cur_p.id_patch)));
491 });
492 shambase::get_check_ref(alpha_av_refs).set_refs(refs);
493 }
494
495 shamrock::solvergraph::SolverGraph &solver_graph = storage.solver_graph;
496
// Output edges and the particle mass are looked up in the shared solver graph.
497 auto axyz_refs = solver_graph.get_edge_ptr<shamrock::solvergraph::FieldRefs<Tvec>>("axyz");
498 auto duint_refs = solver_graph.get_edge_ptr<shamrock::solvergraph::FieldRefs<Tscal>>("duint");
499 auto gpart_mass
500 = solver_graph.get_edge_ptr<shamrock::solvergraph::ScalarEdge<Tscal>>("gpart_mass");
501
// Scalar parameters from cfg are wrapped as ScalarEdge values for the node.
502 std::shared_ptr<shamrock::solvergraph::ScalarEdge<Tscal>> alpha_u
503 = std::make_shared<shamrock::solvergraph::ScalarEdge<Tscal>>("alpha_u", "alpha_u");
504 {
505 shambase::get_check_ref(alpha_u).value = cfg.alpha_u;
506 }
507 std::shared_ptr<shamrock::solvergraph::ScalarEdge<Tscal>> beta_AV
508 = std::make_shared<shamrock::solvergraph::ScalarEdge<Tscal>>("beta_AV", "beta_AV");
509 {
510 shambase::get_check_ref(beta_AV).value = cfg.beta_AV;
511 }
512
// The node does the actual work in evaluate(); axyz_refs and duint_refs are passed
// last, presumably as its outputs.
513 std::shared_ptr<NodeUpdateDerivsVaryingAlphaAV<Tvec, SPHKernel>> node
514 = std::make_shared<NodeUpdateDerivsVaryingAlphaAV<Tvec, SPHKernel>>();
515 {
516 node->set_edges(
517 gpart_mass,
518 alpha_u,
519 beta_AV,
520 part_counts,
521 part_counts_with_ghost,
522 xyz_refs,
523 hpart_refs,
524 vxyz_refs,
525 uint_refs,
526 omega_refs,
527 pressure_field,
528 soundspeed_field,
529 alpha_av_refs,
530 storage.neigh_cache,
531 axyz_refs,
532 duint_refs);
533 }
534 node->evaluate();
535}
536
/**
 * @brief Computes axyz and duint for all local particles with the constant
 * "disc" artificial viscosity (ConstantDisc configuration).
 *
 * For each non-empty patch, one device kernel visits every local particle a,
 * loops over its cached neighbours b and accumulates pair contributions through
 * add_to_derivs_sph_artif_visco_cond. The pair AV terms come from
 * shamrock::sph::q_av_disc, which unlike q_av also receives h, r_ab, alpha and
 * c_s. The results overwrite axyz[id_a] and duint[id_a].
 *
 * @param cfg Disc-AV parameters; alpha_AV, alpha_u and beta_AV are read.
 */
537template<class Tvec, template<class> class SPHKernel>
538void shammodels::sph::modules::UpdateDerivs<Tvec, SPHKernel>::update_derivs_disc_visco(
539 ConstantDisc cfg) {
540 StackEntry stack_loc{};
541
542 using namespace shamrock;
543 using namespace shamrock::patch;
544
545 PatchDataLayerLayout &pdl = scheduler().pdl_old();
546
// NOTE(review): of the pdl indices below only iaxyz and iduint are used; ixyz,
// ivxyz, iuint and ihpart are unused.
547 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
548 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
549 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
550 const u32 iuint = pdl.get_field_idx<Tscal>("uint");
551 const u32 iduint = pdl.get_field_idx<Tscal>("duint");
552 const u32 ihpart = pdl.get_field_idx<Tscal>("hpart");
553
555 = shambase::get_check_ref(storage.ghost_layout.get());
556 u32 ihpart_interf = ghost_layout.get_field_idx<Tscal>("hpart");
557 u32 iuint_interf = ghost_layout.get_field_idx<Tscal>("uint");
558 u32 ivxyz_interf = ghost_layout.get_field_idx<Tvec>("vxyz");
559 u32 iomega_interf = ghost_layout.get_field_idx<Tscal>("omega");
560
561 auto &merged_xyzh = storage.merged_xyzh.get();
563 shambase::DistributedData<PatchDataLayer> &mpdats = storage.merged_patchdata_ghost.get();
564
565 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
566 PatchDataLayer &mpdat = mpdats.get(cur_p.id_patch);
567
569 = merged_xyzh.get(cur_p.id_patch).template get_field_buf_ref<Tvec>(0);
// Outputs (axyz, duint) live in the local patch data. Inputs come from the merged
// (ghost-including, by name) data, which is what id_b addresses inside the kernel.
570 sham::DeviceBuffer<Tvec> &buf_axyz = pdat.get_field_buf_ref<Tvec>(iaxyz);
571 sham::DeviceBuffer<Tscal> &buf_duint = pdat.get_field_buf_ref<Tscal>(iduint);
572 sham::DeviceBuffer<Tvec> &buf_vxyz = mpdat.get_field_buf_ref<Tvec>(ivxyz_interf);
573 sham::DeviceBuffer<Tscal> &buf_hpart = mpdat.get_field_buf_ref<Tscal>(ihpart_interf);
574 sham::DeviceBuffer<Tscal> &buf_omega = mpdat.get_field_buf_ref<Tscal>(iomega_interf);
575 sham::DeviceBuffer<Tscal> &buf_uint = mpdat.get_field_buf_ref<Tscal>(iuint_interf);
576 sham::DeviceBuffer<Tscal> &buf_pressure
577 = shambase::get_check_ref(storage.pressure).get_field(cur_p.id_patch).get_buf();
579 = shambase::get_check_ref(storage.soundspeed).get_field(cur_p.id_patch).get_buf();
580
// NOTE(review): range_npart is unused; the kernel below is sized with pdat.get_obj_cnt().
581 sycl::range range_npart{pdat.get_obj_cnt()};
582
583 tree::ObjectCache &pcache
584 = shambase::get_check_ref(storage.neigh_cache).get_cache(cur_p.id_patch);
585
587
588 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
589 sham::EventList depends_list;
590
// Acquire device pointers for every buffer the kernel touches; depends_list
// collects what q.submit below has to wait on.
591 auto xyz = buf_xyz.get_read_access(depends_list);
592 auto axyz = buf_axyz.get_write_access(depends_list);
593 auto du = buf_duint.get_write_access(depends_list);
594 auto vxyz = buf_vxyz.get_read_access(depends_list);
595 auto hpart = buf_hpart.get_read_access(depends_list);
596 auto omega = buf_omega.get_read_access(depends_list);
597 auto u = buf_uint.get_read_access(depends_list);
598 auto pressure = buf_pressure.get_read_access(depends_list);
599 auto cs = buf_cs.get_read_access(depends_list);
600 auto ploop_ptrs = pcache.get_read_access(depends_list);
601
602 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
603 const Tscal pmass = solver_config.gpart_mass;
604 const Tscal alpha_AV = cfg.alpha_AV;
605 const Tscal alpha_u = cfg.alpha_u;
606 const Tscal beta_AV = cfg.beta_AV;
607
608 shamlog_debug_sycl_ln("deriv kernel", "alpha_AV :", alpha_AV);
609 shamlog_debug_sycl_ln("deriv kernel", "alpha_u :", alpha_u);
610 shamlog_debug_sycl_ln("deriv kernel", "beta_AV :", beta_AV);
611
612 // tree::ObjectIterator particle_looper(tree,cgh);
613
614 // tree::LeafCacheObjectIterator
615 // particle_looper(tree,*xyz_cell_id,leaf_cache,cgh);
616
617 tree::ObjectCacheIterator particle_looper(ploop_ptrs);
618
619 // sycl::accessor hmax_tree{tree_field_hmax, cgh, sycl::read_only};
620
621 // sycl::stream out {4096,1024,cgh};
622
// Squared kernel support radius (in units of h), used as the neighbour cut-off.
623 constexpr Tscal Rker2 = Kernel::Rkern * Kernel::Rkern;
624
625 shambase::parallel_for(cgh, pdat.get_obj_cnt(), "compute force disc", [=](u64 gid) {
626 u32 id_a = (u32) gid;
627
628 using namespace shamrock::sph;
629
// NOTE(review): sum_axyz and sum_du_a are never used; force_pressure and
// tmpdU_pressure below are the actual accumulators.
630 Tvec sum_axyz = {0, 0, 0};
631 Tscal sum_du_a = 0;
632
633 Tscal h_a = hpart[id_a];
634 Tvec xyz_a = xyz[id_a];
635 Tvec vxyz_a = vxyz[id_a];
636 Tscal P_a = pressure[id_a];
637 Tscal cs_a = cs[id_a];
638 Tscal omega_a = omega[id_a];
639 const Tscal u_a = u[id_a];
640
641 Tscal rho_a = rho_h(pmass, h_a, Kernel::hfactd);
642 Tscal rho_a_sq = rho_a * rho_a;
// NOTE(review): `1.` is a double literal, so for a single-precision Tscal this
// division is evaluated in double precision inside the device kernel;
// `Tscal{1} / rho_a` would stay in Tscal (matters on devices without fp64).
643 Tscal rho_a_inv = 1. / rho_a;
644
645 // f32 P_a = cs * cs * rho_a;
646
647 Tscal omega_a_rho_a_inv = 1 / (omega_a * rho_a);
648
649 Tvec force_pressure{0, 0, 0};
650 Tscal tmpdU_pressure = 0;
651
// Loop over cached neighbours of a. The pre-existing "compute only omega_a" comment
// inside the loop looks stale: the loop accumulates force / du contributions.
652 particle_looper.for_each_object(id_a, [&](u32 id_b) {
653 // compute only omega_a
654 Tvec dr = xyz_a - xyz[id_b];
655 Tscal rab2 = sycl::dot(dr, dr);
656 Tscal h_b = hpart[id_b];
657
// Skip pairs lying outside the kernel support of both particles.
658 if (rab2 > h_a * h_a * Rker2 && rab2 > h_b * h_b * Rker2) {
659 return;
660 }
661
662 Tvec vxyz_b = vxyz[id_b];
663 const Tscal u_b = u[id_b];
664 Tscal P_b = pressure[id_b];
665 Tscal omega_b = omega[id_b];
666 Tscal cs_b = cs[id_b];
667
668 Tscal rab = sycl::sqrt(rab2);
669
670 Tscal rho_b = rho_h(pmass, h_b, Kernel::hfactd);
// Constant disc AV: both particles of the pair use the same alpha.
671 const Tscal alpha_a = alpha_AV;
672 const Tscal alpha_b = alpha_AV;
673 Tscal Fab_a = Kernel::dW_3d(rab, h_a);
674 Tscal Fab_b = Kernel::dW_3d(rab, h_b);
675
676 Tvec v_ab = vxyz_a - vxyz_b;
677
// inv_sat_positive presumably saturates for rab == 0 (self or coincident pair)
// instead of dividing by zero -- TODO confirm in shambackends/math.hpp.
678 Tvec r_ab_unit = dr * sham::inv_sat_positive(rab);
679
680 // f32 P_b = cs * cs * rho_b;
681 Tscal v_ab_r_ab = sycl::dot(v_ab, r_ab_unit);
682 Tscal abs_v_ab_r_ab = sycl::fabs(v_ab_r_ab);
683
684 Tscal vsig_a = alpha_a * cs_a + beta_AV * abs_v_ab_r_ab;
685 Tscal vsig_b = alpha_b * cs_b + beta_AV * abs_v_ab_r_ab;
686
687 Tscal vsig_u = shamrock::sph::vsig_u(P_a, P_b, rho_a, rho_b);
688
// Disc-specific AV term: q_av_disc additionally takes h, r_ab, alpha and c_s.
689 Tscal qa_ab = shamrock::sph::q_av_disc(
690 rho_a, h_a, rab, alpha_a, cs_a, vsig_a, v_ab_r_ab);
691 Tscal qb_ab = shamrock::sph::q_av_disc(
692 rho_b, h_b, rab, alpha_b, cs_b, vsig_b, v_ab_r_ab);
693
694 add_to_derivs_sph_artif_visco_cond(
695 pmass,
696 rho_a_sq,
697 omega_a_rho_a_inv,
698 rho_a_inv,
699 rho_b,
700 omega_a,
701 omega_b,
702 Fab_a,
703 Fab_b,
704 u_a,
705 u_b,
706 P_a,
707 P_b,
708 alpha_u,
709 v_ab,
710 r_ab_unit,
711 vsig_u,
712 qa_ab,
713 qb_ab,
714
715 force_pressure,
716 tmpdU_pressure);
717 });
718
// Overwrite (not accumulate into) the outputs of particle a.
719 axyz[id_a] = force_pressure;
720 du[id_a] = tmpdU_pressure;
721 });
722 });
723
// Hand the kernel event back to every buffer accessed above and to the neighbour cache.
724 buf_xyz.complete_event_state(e);
725 buf_axyz.complete_event_state(e);
726 buf_duint.complete_event_state(e);
727 buf_vxyz.complete_event_state(e);
728 buf_hpart.complete_event_state(e);
729 buf_omega.complete_event_state(e);
730 buf_uint.complete_event_state(e);
731 buf_pressure.complete_event_state(e);
732 buf_cs.complete_event_state(e);
733
734 sham::EventList resulting_events;
735 resulting_events.add_event(e);
736 pcache.complete_event_state(resulting_events);
737 });
738}
739
740template<class Tvec, template<class> class SPHKernel>
741void shammodels::sph::modules::UpdateDerivs<Tvec, SPHKernel>::update_derivs_MHD(IdealMHD cfg) {
742 StackEntry stack_loc{};
743
744 using namespace shamrock;
745 using namespace shamrock::patch;
746
747 PatchDataLayerLayout &pdl = scheduler().pdl_old();
748
749 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
750 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
751 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
752 const u32 iuint = pdl.get_field_idx<Tscal>("uint");
753 const u32 iduint = pdl.get_field_idx<Tscal>("duint");
754 const u32 ihpart = pdl.get_field_idx<Tscal>("hpart");
755 const u32 iB_on_rho = pdl.get_field_idx<Tvec>("B/rho");
756 const u32 idB_on_rho = pdl.get_field_idx<Tvec>("dB/rho");
757 const u32 ipsi_on_ch = pdl.get_field_idx<Tscal>("psi/ch");
758 const u32 idpsi_on_ch = pdl.get_field_idx<Tscal>("dpsi/ch");
759 const u32 idrho_dt = pdl.get_field_idx<Tscal>("drho/dt");
760
761 bool do_MHD_debug = solver_config.do_MHD_debug();
762 const u32 imag_pressure = (do_MHD_debug) ? pdl.get_field_idx<Tvec>("mag_pressure") : -1;
763 const u32 imag_tension = (do_MHD_debug) ? pdl.get_field_idx<Tvec>("mag_tension") : -1;
764 const u32 igas_pressure = (do_MHD_debug) ? pdl.get_field_idx<Tvec>("gas_pressure") : -1;
765 const u32 itensile_corr = (do_MHD_debug) ? pdl.get_field_idx<Tvec>("tensile_corr") : -1;
766 const u32 ipsi_propag = (do_MHD_debug) ? pdl.get_field_idx<Tscal>("psi_propag") : -1;
767 const u32 ipsi_diff = (do_MHD_debug) ? pdl.get_field_idx<Tscal>("psi_diff") : -1;
768 const u32 ipsi_cons = (do_MHD_debug) ? pdl.get_field_idx<Tscal>("psi_cons") : -1;
769 const u32 iu_mhd = (do_MHD_debug) ? pdl.get_field_idx<Tscal>("u_mhd") : -1;
770
771 // Tscal mu_0 = 1.;
772 Tscal const mu_0 = solver_config.get_constant_mu_0();
773
775 = shambase::get_check_ref(storage.ghost_layout.get());
776 u32 ihpart_interf = ghost_layout.get_field_idx<Tscal>("hpart");
777 u32 iuint_interf = ghost_layout.get_field_idx<Tscal>("uint");
778 u32 ivxyz_interf = ghost_layout.get_field_idx<Tvec>("vxyz");
779 u32 iomega_interf = ghost_layout.get_field_idx<Tscal>("omega");
780 u32 iB_on_rho_interf = ghost_layout.get_field_idx<Tvec>("B/rho");
781 u32 ipsi_on_ch_interf = ghost_layout.get_field_idx<Tscal>("psi/ch");
782
783 // logger::raw_ln("charged the ghost fields.");
784
785 auto &merged_xyzh = storage.merged_xyzh.get();
787 shambase::DistributedData<PatchDataLayer> &mpdats = storage.merged_patchdata_ghost.get();
788
789 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
790 PatchDataLayer &mpdat = mpdats.get(cur_p.id_patch);
791
793 = merged_xyzh.get(cur_p.id_patch).template get_field_buf_ref<Tvec>(0);
794 sham::DeviceBuffer<Tvec> &buf_axyz = pdat.get_field_buf_ref<Tvec>(iaxyz);
795 sham::DeviceBuffer<Tscal> &buf_duint = pdat.get_field_buf_ref<Tscal>(iduint);
796 sham::DeviceBuffer<Tvec> &buf_vxyz = mpdat.get_field_buf_ref<Tvec>(ivxyz_interf);
797 sham::DeviceBuffer<Tscal> &buf_hpart = mpdat.get_field_buf_ref<Tscal>(ihpart_interf);
798 sham::DeviceBuffer<Tscal> &buf_omega = mpdat.get_field_buf_ref<Tscal>(iomega_interf);
799 sham::DeviceBuffer<Tscal> &buf_uint = mpdat.get_field_buf_ref<Tscal>(iuint_interf);
800 sham::DeviceBuffer<Tscal> &buf_pressure
801 = shambase::get_check_ref(storage.pressure).get_field(cur_p.id_patch).get_buf();
803 = shambase::get_check_ref(storage.soundspeed).get_field(cur_p.id_patch).get_buf();
804
805 sham::DeviceBuffer<Tvec> &buf_dB_on_rho = pdat.get_field_buf_ref<Tvec>(idB_on_rho);
806 sham::DeviceBuffer<Tscal> &buf_dpsi_on_ch = pdat.get_field_buf_ref<Tscal>(idpsi_on_ch);
807 sham::DeviceBuffer<Tscal> &buf_drho_dt = pdat.get_field_buf_ref<Tscal>(idrho_dt);
808 // logger::raw_ln("charged dB dpsi");
809
810 sham::DeviceBuffer<Tvec> &buf_B_on_rho = mpdat.get_field_buf_ref<Tvec>(iB_on_rho_interf);
811 sham::DeviceBuffer<Tscal> &buf_psi_on_ch
812 = mpdat.get_field_buf_ref<Tscal>(ipsi_on_ch_interf);
813
814 // logger::raw_ln("charged B psi");
815 // ADD curlBBBBBBBBB
816
817 sycl::range range_npart{pdat.get_obj_cnt()};
818
819 tree::ObjectCache &pcache
820 = shambase::get_check_ref(storage.neigh_cache).get_cache(cur_p.id_patch);
821
823
824 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
825 sham::EventList depends_list;
826
827 auto xyz = buf_xyz.get_read_access(depends_list);
828 auto axyz = buf_axyz.get_write_access(depends_list);
829 auto du = buf_duint.get_write_access(depends_list);
830 auto vxyz = buf_vxyz.get_read_access(depends_list);
831 auto hpart = buf_hpart.get_read_access(depends_list);
832 auto omega = buf_omega.get_read_access(depends_list);
833 auto u = buf_uint.get_read_access(depends_list);
834 auto pressure = buf_pressure.get_read_access(depends_list);
835 auto cs = buf_cs.get_read_access(depends_list);
836 auto B_on_rho = buf_B_on_rho.get_read_access(depends_list);
837 auto psi_on_ch = buf_psi_on_ch.get_read_access(depends_list);
838 auto dB_on_rho = buf_dB_on_rho.get_write_access(depends_list);
839 auto dpsi_on_ch = buf_dpsi_on_ch.get_write_access(depends_list);
840 auto drho_dt = buf_drho_dt.get_write_access(depends_list);
841
842 Tvec *mag_pressure
843 = (do_MHD_debug)
844 ? pdat.get_field_buf_ref<Tvec>(imag_pressure).get_write_access(depends_list)
845 : nullptr;
846 Tvec *mag_tension
847 = (do_MHD_debug)
848 ? pdat.get_field_buf_ref<Tvec>(imag_tension).get_write_access(depends_list)
849 : nullptr;
850 Tvec *gas_pressure
851 = (do_MHD_debug)
852 ? pdat.get_field_buf_ref<Tvec>(igas_pressure).get_write_access(depends_list)
853 : nullptr;
854 Tvec *tensile_corr
855 = (do_MHD_debug)
856 ? pdat.get_field_buf_ref<Tvec>(itensile_corr).get_write_access(depends_list)
857 : nullptr;
858
859 Tscal *psi_propag
860 = (do_MHD_debug)
861 ? pdat.get_field_buf_ref<Tscal>(ipsi_propag).get_write_access(depends_list)
862 : nullptr;
863 Tscal *psi_diff
864 = (do_MHD_debug)
865 ? pdat.get_field_buf_ref<Tscal>(ipsi_diff).get_write_access(depends_list)
866 : nullptr;
867 Tscal *psi_cons
868 = (do_MHD_debug)
869 ? pdat.get_field_buf_ref<Tscal>(ipsi_cons).get_write_access(depends_list)
870 : nullptr;
871
872 Tscal *u_mhd = (do_MHD_debug)
873 ? pdat.get_field_buf_ref<Tscal>(iu_mhd).get_write_access(depends_list)
874 : nullptr;
875
876 auto ploop_ptrs = pcache.get_read_access(depends_list);
877
878 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
879 const Tscal pmass = solver_config.gpart_mass;
880 const Tscal sigma_mhd = cfg.sigma_mhd;
881 const Tscal alpha_u = cfg.alpha_u;
882
883 shamlog_debug_ln("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@", "");
884 shamlog_debug_sycl_ln("deriv kernel", "sigma_mhd :", sigma_mhd);
885 shamlog_debug_sycl_ln("deriv kernel", "alpha_u :", alpha_u);
886 shamlog_debug_ln("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@", "");
887
888 tree::ObjectCacheIterator particle_looper(ploop_ptrs);
889
890 constexpr Tscal Rker2 = Kernel::Rkern * Kernel::Rkern;
891
892 shambase::parallel_for(cgh, pdat.get_obj_cnt(), "compute MHD", [=](u64 gid) {
893 u32 id_a = (u32) gid;
894
895 using namespace shamrock::sph;
896
897 Tvec sum_axyz = {0, 0, 0};
898 Tscal sum_du_a = 0;
899
900 Tscal h_a = hpart[id_a];
901 Tvec xyz_a = xyz[id_a];
902 Tvec vxyz_a = vxyz[id_a];
903 Tscal P_a = pressure[id_a];
904 Tscal cs_a = cs[id_a];
905 Tscal omega_a = omega[id_a];
906 const Tscal u_a = u[id_a];
907
908 Tscal rho_a = rho_h(pmass, h_a, Kernel::hfactd);
909 Tscal rho_a_sq = rho_a * rho_a;
910 Tscal rho_a_inv = 1. / rho_a;
911
912 Tvec B_a = B_on_rho[id_a] * rho_a;
913 Tscal v_alfven_a = sycl::sqrt(sycl::dot(B_a, B_a) / (mu_0 * rho_a));
914 Tscal v_shock_a = sycl::sqrt(cs_a * cs_a + v_alfven_a * v_alfven_a);
915 Tscal psi_a = psi_on_ch[id_a] * v_shock_a;
916
917 Tscal omega_a_rho_a_inv = 1 / (omega_a * rho_a);
918
919 Tvec force_pressure{0, 0, 0};
920 Tscal tmpdU_pressure = 0;
921 Tvec magnetic_eq{0, 0, 0};
922 Tscal psi_eq = 0;
923 Tscal drho_eq = 0;
924
925 Tvec mag_pressure_term{0, 0, 0};
926 Tvec mag_tension_term{0, 0, 0};
927 Tvec gas_pressure_term{0, 0, 0};
928 Tvec tensile_corr_term{0, 0, 0};
929
930 Tscal psi_propag_term = 0;
931 Tscal psi_diff_term = 0;
932 Tscal psi_cons_term = 0;
933
934 Tscal u_mhd_term = 0;
935
936 particle_looper.for_each_object(id_a, [&](u32 id_b) {
937 // compute only omega_a
938 Tvec dr = xyz_a - xyz[id_b];
939 Tscal rab2 = sycl::dot(dr, dr);
940 Tscal h_b = hpart[id_b];
941
942 if (rab2 > h_a * h_a * Rker2 && rab2 > h_b * h_b * Rker2) {
943 return;
944 }
945
946 Tvec vxyz_b = vxyz[id_b];
947 const Tscal u_b = u[id_b];
948 Tscal P_b = pressure[id_b];
949 Tscal omega_b = omega[id_b];
950 Tscal cs_b = cs[id_b];
951
952 Tscal rab = sycl::sqrt(rab2);
953
954 Tscal rho_b = rho_h(pmass, h_b, Kernel::hfactd);
955 Tvec B_b = B_on_rho[id_b] * rho_b;
956 Tscal v_alfven_b = sycl::sqrt(sycl::dot(B_b, B_b) / (mu_0 * rho_b));
957 Tscal v_shock_b = sycl::sqrt(cs_b * cs_b + v_alfven_b * v_alfven_b);
958 Tscal psi_b = psi_on_ch[id_b] * v_shock_b;
959 // const Tscal alpha_a = alpha_AV;
960 // const Tscal alpha_b = alpha_AV;
961 Tscal Fab_a = Kernel::dW_3d(rab, h_a);
962 Tscal Fab_b = Kernel::dW_3d(rab, h_b);
963
964 // Tscal sigma_mhd = 0.3;
965 shamrock::sph::mhd::add_to_derivs_spmhd<Kernel, Tvec, Tscal>(
966 pmass,
967 dr,
968 rab,
969 rho_a,
970 rho_a_sq,
971 omega_a_rho_a_inv,
972 rho_a_inv,
973 rho_b,
974 omega_a,
975 omega_b,
976 Fab_a,
977 Fab_b,
978 vxyz_a,
979 vxyz_b,
980 u_a,
981 u_b,
982 P_a,
983 P_b,
984 cs_a,
985 cs_b,
986 h_a,
987 h_b,
988
989 alpha_u,
990
991 B_a,
992 B_b,
993
994 psi_a,
995 psi_b,
996
997 mu_0,
998 sigma_mhd,
999
1000 force_pressure,
1001 tmpdU_pressure,
1002 magnetic_eq,
1003 psi_eq,
1004 drho_eq,
1005 mag_pressure_term,
1006 mag_tension_term,
1007 gas_pressure_term,
1008 tensile_corr_term,
1009
1010 psi_propag_term,
1011 psi_diff_term,
1012 psi_cons_term,
1013 u_mhd_term);
1014 });
1015
1016 axyz[id_a] = force_pressure;
1017 du[id_a] = tmpdU_pressure;
1018 dB_on_rho[id_a] = magnetic_eq;
1019 dpsi_on_ch[id_a] = psi_eq - psi_a / h_a;
1020 drho_dt[id_a] = drho_eq;
1021
1022 if (do_MHD_debug) {
1023 mag_pressure[id_a] = mag_pressure_term;
1024 mag_tension[id_a] = mag_tension_term;
1025 gas_pressure[id_a] = gas_pressure_term;
1026 tensile_corr[id_a] = tensile_corr_term;
1027
1028 psi_propag[id_a] = psi_propag_term;
1029 psi_diff[id_a] = psi_diff_term;
1030 psi_cons[id_a] = -psi_a / h_a;
1031
1032 u_mhd[id_a] = u_mhd_term;
1033 }
1034 });
1035 });
1036
1037 buf_xyz.complete_event_state(e);
1038 buf_axyz.complete_event_state(e);
1039 buf_duint.complete_event_state(e);
1040 buf_vxyz.complete_event_state(e);
1041 buf_hpart.complete_event_state(e);
1042 buf_omega.complete_event_state(e);
1043 buf_uint.complete_event_state(e);
1044 buf_pressure.complete_event_state(e);
1045 buf_cs.complete_event_state(e);
1046 buf_B_on_rho.complete_event_state(e);
1047 buf_psi_on_ch.complete_event_state(e);
1048 buf_dB_on_rho.complete_event_state(e);
1049 buf_dpsi_on_ch.complete_event_state(e);
1050 buf_drho_dt.complete_event_state(e);
1051
1052 if (do_MHD_debug) {
1053 pdat.get_field_buf_ref<Tvec>(imag_pressure).complete_event_state(e);
1054 pdat.get_field_buf_ref<Tvec>(imag_tension).complete_event_state(e);
1055 pdat.get_field_buf_ref<Tvec>(igas_pressure).complete_event_state(e);
1056 pdat.get_field_buf_ref<Tvec>(itensile_corr).complete_event_state(e);
1057
1058 pdat.get_field_buf_ref<Tscal>(ipsi_propag).complete_event_state(e);
1059 pdat.get_field_buf_ref<Tscal>(ipsi_diff).complete_event_state(e);
1060 pdat.get_field_buf_ref<Tscal>(ipsi_cons).complete_event_state(e);
1061
1062 pdat.get_field_buf_ref<Tscal>(iu_mhd).complete_event_state(e);
1063 }
1064
1065 sham::EventList resulting_events;
1066 resulting_events.add_event(e);
1067 pcache.complete_event_state(resulting_events);
1068 });
1069}
1070
1071template<class Tvec, template<class> class SPHKernel>
1072void shammodels::sph::modules::UpdateDerivs<Tvec, SPHKernel>::update_derivs_dust_monofluid_tvi_Sj(
1073 DustConfig cfg, Tscal dt_hydro) {
1074
1075 using MonofluidTVI = typename DustConfig::MonofluidTVI;
1076
1077 StackEntry stack_loc{};
1078
1079 using namespace shamrock;
1080 using namespace shamrock::patch;
1081
1082 PatchDataLayerLayout &pdl = scheduler().pdl_old();
1083
1084 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
1085 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
1086 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
1087 const u32 ihpart = pdl.get_field_idx<Tscal>("hpart");
1088 const u32 is_j = pdl.get_field_idx<Tscal>("s_j");
1089 const u32 ids_j_dt = pdl.get_field_idx<Tscal>("ds_j_dt");
1090
1092 = shambase::get_check_ref(storage.ghost_layout.get());
1093 u32 ihpart_interf = ghost_layout.get_field_idx<Tscal>("hpart");
1094 u32 ivxyz_interf = ghost_layout.get_field_idx<Tvec>("vxyz");
1095 u32 iomega_interf = ghost_layout.get_field_idx<Tscal>("omega");
1096 u32 is_j_interf = ghost_layout.get_field_idx<Tscal>("s_j");
1097
1098 u32 ndust = cfg.get_dust_nvar();
1099
1100 auto &merged_xyzh = storage.merged_xyzh.get();
1102 shambase::DistributedData<PatchDataLayer> &mpdats = storage.merged_patchdata_ghost.get();
1103
1104 auto &part_counts = storage.part_counts;
1105 auto &part_counts_with_ghost = storage.part_counts_with_ghost;
1106 auto &xyz_refs = storage.positions_with_ghosts;
1107 auto &pressure_field = storage.pressure;
1108
1109 shamrock::solvergraph::SolverGraph &solver_graph = storage.solver_graph;
1110 auto gpart_mass
1111 = solver_graph.get_edge_ptr<shamrock::solvergraph::ScalarEdge<Tscal>>("gpart_mass");
1112
1113 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tvec>> vxyz_refs
1114 = std::make_shared<shamrock::solvergraph::FieldRefs<Tvec>>("vxyz", "v");
1115 {
1116 shambase::get_check_ref(vxyz_refs).set_refs(
1117 mpdats.map<std::reference_wrapper<PatchDataField<Tvec>>>(
1118 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
1119 return std::ref(mpdat.get_field<Tvec>(ivxyz_interf));
1120 }));
1121 }
1122
1123 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> hpart_refs
1124 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("hpart", "h");
1125 { // if was just reset before this call
1126 shambase::get_check_ref(hpart_refs)
1127 .set_refs(mpdats.map<std::reference_wrapper<PatchDataField<Tscal>>>(
1128 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
1129 return std::ref(mpdat.get_field<Tscal>(ihpart_interf));
1130 }));
1131 }
1132
1133 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> omega_refs
1134 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("omega", "omega");
1135 {
1136 shambase::get_check_ref(omega_refs)
1137 .set_refs(mpdats.map<std::reference_wrapper<PatchDataField<Tscal>>>(
1138 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
1139 return std::ref(mpdat.get_field<Tscal>(iomega_interf));
1140 }));
1141 }
1142
1143 // s_j_interf
1144 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> s_j_refs
1145 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("s_j", "s_j");
1146 {
1147 shambase::get_check_ref(s_j_refs).set_refs(
1148 mpdats.map<std::reference_wrapper<PatchDataField<Tscal>>>(
1149 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
1150 return std::ref(mpdat.get_field<Tscal>(is_j_interf));
1151 }));
1152 }
1153
1154 std::shared_ptr<shamrock::solvergraph::Field<Tscal>> t_j_field
1155 = std::make_shared<shamrock::solvergraph::Field<Tscal>>(ndust, "t_j", "t_j");
1156
1157 using None = typename DustConfig::None;
1158 using ConstantStoppingTimes = typename DustConfig::ConstantStoppingTimes;
1159 using EpsteinDrag = typename DustConfig::EpsteinDrag;
1160
1161 if (std::holds_alternative<None>(cfg.dust_drag_mode)) {
1162
1163 throw "bro WTF";
1164
1165 } else if (
1166 ConstantStoppingTimes *cfg_drag = std::get_if<ConstantStoppingTimes>(&cfg.dust_drag_mode)) {
1167
1168 std::shared_ptr<shamrock::solvergraph::ScalarEdge<std::vector<Tscal>>> input_t_j
1169 = std::make_shared<shamrock::solvergraph::ScalarEdge<std::vector<Tscal>>>("", "");
1170 input_t_j->value = cfg_drag->stopping_times;
1171
1172 std::shared_ptr<SetDustStoppingTimeConstant<Tvec>> node_set_tj
1173 = std::make_shared<SetDustStoppingTimeConstant<Tvec>>(ndust);
1174 {
1175 node_set_tj->set_edges(input_t_j, part_counts_with_ghost, t_j_field);
1176 }
1177 node_set_tj->evaluate();
1178
1179 } else if (EpsteinDrag *cfg_drag = std::get_if<EpsteinDrag>(&cfg.dust_drag_mode)) {
1180
1181 std::shared_ptr<shamrock::solvergraph::ScalarEdge<Tscal>> input_gamma
1182 = std::make_shared<shamrock::solvergraph::ScalarEdge<Tscal>>("", "");
1183 input_gamma->value = cfg_drag->gamma;
1184
1185 std::shared_ptr<shamrock::solvergraph::ScalarEdge<std::vector<Tscal>>> input_sgrain_j
1186 = std::make_shared<shamrock::solvergraph::ScalarEdge<std::vector<Tscal>>>("", "");
1187 input_sgrain_j->value = cfg_drag->grains_sizes;
1188
1189 std::shared_ptr<shamrock::solvergraph::ScalarEdge<std::vector<Tscal>>> input_rho_grain_j
1190 = std::make_shared<shamrock::solvergraph::ScalarEdge<std::vector<Tscal>>>("", "");
1191 input_rho_grain_j->value = cfg_drag->grains_densities;
1192
1193 std::shared_ptr<SetDustStoppingTimeEpstein<Tvec, SPHKernel>> node_set_tj
1194 = std::make_shared<SetDustStoppingTimeEpstein<Tvec, SPHKernel>>(ndust);
1195 {
1196 node_set_tj->set_edges(
1197 gpart_mass,
1198 input_gamma,
1199 input_sgrain_j,
1200 input_rho_grain_j,
1201 part_counts_with_ghost,
1202 hpart_refs,
1203 storage.soundspeed,
1204 t_j_field);
1205 }
1206 node_set_tj->evaluate();
1207 }
1208
1209 std::shared_ptr<shamrock::solvergraph::Field<Tscal>> Ttilde_sj_field
1210 = std::make_shared<shamrock::solvergraph::Field<Tscal>>(ndust, "Ttilde_sj", "Ttilde_sj");
1211
1212 auto ds_j_dt_refs
1213 = solver_graph.get_edge_ptr<shamrock::solvergraph::FieldRefs<Tscal>>("ds_j_dt");
1214
1215 std::shared_ptr<ComputeDustTtilde<Tvec, SPHKernel>> node_tj
1216 = std::make_shared<ComputeDustTtilde<Tvec, SPHKernel>>(ndust);
1217 {
1218 node_tj->set_edges(
1219 gpart_mass, part_counts_with_ghost, hpart_refs, s_j_refs, t_j_field, Ttilde_sj_field);
1220 }
1221 node_tj->evaluate();
1222
1223 std::shared_ptr<NodeUpdateDerivsMonofluidTVI<Tvec, SPHKernel>> node
1224 = std::make_shared<NodeUpdateDerivsMonofluidTVI<Tvec, SPHKernel>>(ndust);
1225 {
1226 node->set_edges(
1227 gpart_mass,
1228 part_counts,
1229 part_counts_with_ghost,
1230 xyz_refs,
1231 hpart_refs,
1232 vxyz_refs,
1233 omega_refs,
1234 pressure_field,
1235 s_j_refs,
1236 Ttilde_sj_field,
1237 storage.neigh_cache,
1238 ds_j_dt_refs);
1239 }
1240 node->evaluate();
1241
1242 MonofluidTVI &cfg_monofluid_tvi
1243 = shambase::get_check_ref((std::get_if<MonofluidTVI>(&cfg.current_mode)));
1244
1245 if (cfg_monofluid_tvi.pure_diffusion_mode) {
1246 // reset accelerations & du/dt to 0
1247
1248 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
1249 const u32 iduint = pdl.get_field_idx<Tscal>("duint");
1250
1251 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
1252 pdat.get_field_buf_ref<Tvec>(iaxyz).fill({0, 0, 0});
1253 pdat.get_field_buf_ref<Tscal>(iduint).fill(0);
1254 });
1255 }
1256}
1257
1258using namespace shammath;
1262
Compute the combined dust stopping times Ttilde_sj for each dust species j (see Hutchison 2018, eq. 15).
constexpr const char * axyz
3-acceleration field
constexpr const char * vxyz
3-velocity field
constexpr const char * part_counts_with_ghost
Particle counts including ghosts.
constexpr const char * xyz
Position field (3D coordinates).
constexpr const char * part_counts
Particle counts per patch.
constexpr const char * pressure
Pressure P (derived from EOS).
constexpr const char * hpart
Smoothing length field.
constexpr const char * omega
Grad-h correction factor \Omega.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory).
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
void add_event(sycl::event e)
Add an event to the list of events.
Definition EventList.hpp:87
Represents a collection of objects distributed across patches identified by a u64 id.
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
DistributedData< Tmap > map(std::function< Tmap(u64, T &)> map_func)
Apply a function to all objects in the collection and return a new collection containing the results.
T & get(u64 id)
Returns a reference to an object in the collection.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
A graph container for managing solver nodes and edges with type-safe access.
std::shared_ptr< T > get_edge_ptr(const std::string &name)
Get a typed shared pointer to an edge by name.
T inv_sat_positive(T v, T minvsat=T{1e-9}, T satval=T{0.}) noexcept
Saturated inverse (positive numbers only).
Definition math.hpp:841
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
Namespace for math utilities.
Definition AABB.hpp:26
Namespace for the main framework.
Definition __init__.py:1
constexpr Tscal q_av(const Tscal &rho, const Tscal &vsig, const Tscal &v_scal_rhat)
phantom_2018 eq.40
Definition q_ab.hpp:37
File containing formulas for SPH-MHD forces and for the evolution of the magnetic and divergence-cleaning fields.
File containing formulas for SPH forces.
SPH kernels.
shambase::details::BasicStackEntry StackEntry
Alias for shambase::details::BasicStackEntry.
Patch object that contain generic patch information.
Definition Patch.hpp:33
u64 id_patch
unique key that identify the patch
Definition Patch.hpp:86