Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
ExternalForces.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
17
#include "shambase/memory.hpp"
#include "shamcomm/logs.hpp"
#include <type_traits>
39
namespace shambase {

    /**
     * @brief Store a value in a newly allocated std::shared_ptr.
     *
     * Rvalue arguments are moved into the shared object; lvalue arguments are copied.
     * The pointee type is `std::remove_reference_t<T>`, so a const argument yields a
     * `std::shared_ptr<const U>`.
     *
     * @param t the value to store
     * @return a shared pointer owning the new object
     */
    template<class T>
    std::shared_ptr<std::remove_reference_t<T>> to_shared(T &&t) {
        // `T &&` is a forwarding reference: for an lvalue argument `T` deduces to `U &`,
        // and `std::make_shared<U &>` is ill-formed, hence the reference is stripped.
        return std::make_shared<std::remove_reference_t<T>>(std::forward<T>(t));
    }

} // namespace shambase
47
48template<class Tvec, template<class> class SPHKernel>
50
51 StackEntry stack_loc{};
52
53 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
54
55 Tscal gpart_mass = solver_config.gpart_mass;
56
57 using namespace shamrock;
58 using namespace shamrock::patch;
59
60 PatchDataLayerLayout &pdl = scheduler().pdl_old();
61
62 const u32 iaxyz_ext = pdl.get_field_idx<Tvec>("axyz_ext");
63 modules::SinkParticlesUpdate<Tvec, SPHKernel> sink_update(context, solver_config, storage);
64
65 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
66 PatchDataField<Tvec> &field = pdat.get_field<Tvec>(iaxyz_ext);
67 field.field_raz();
68 });
69
70 sink_update.compute_sph_forces();
71
72 if (solver_config.ext_force_config.ext_forces.empty()) {
73 return;
74 }
75
76 auto field_xyz = shamrock::solvergraph::FieldRefs<Tvec>::make_shared("", "");
77
79 [&](shamrock::solvergraph::FieldRefs<Tvec> &field_xyz_edge) {
81 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
82 auto &field = pdat.get_field<Tvec>(0);
83 field_xyz_refs.add_obj(p.id_patch, std::ref(field));
84 });
85 field_xyz_edge.set_refs(field_xyz_refs);
86 });
87 set_field_xyz.set_edges(field_xyz);
88 set_field_xyz.evaluate();
89
90 auto field_axyz_ext = shamrock::solvergraph::FieldRefs<Tvec>::make_shared("", "");
91
93 [&](shamrock::solvergraph::FieldRefs<Tvec> &field_axyz_ext_edge) {
95 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
96 auto &field = pdat.get_field<Tvec>(iaxyz_ext);
97 field_axyz_ext_refs.add_obj(p.id_patch, std::ref(field));
98 });
99 field_axyz_ext_edge.set_refs(field_axyz_ext_refs);
100 });
101 set_field_axyz_ext.set_edges(field_axyz_ext);
102 set_field_axyz_ext.evaluate();
103
104 auto sizes = shamrock::solvergraph::Indexes<u32>::make_shared("", "");
105
108 sizes.indexes = {};
109 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
110 sizes.indexes.add_obj(p.id_patch, pdat.get_obj_cnt());
111 });
112 });
113 set_sizes.set_edges(sizes);
114 set_sizes.evaluate();
115
116 auto constant_G = shamrock::solvergraph::IDataEdge<Tscal>::make_shared("", "");
117 auto constant_c = shamrock::solvergraph::IDataEdge<Tscal>::make_shared("", "");
118
121 constant_G.data = solver_config.get_constant_G();
122 });
123
126 constant_c.data = solver_config.get_constant_c();
127 });
128
129 set_constant_G.set_edges(constant_G);
130 set_constant_c.set_edges(constant_c);
131
132 std::vector<std::shared_ptr<shamrock::solvergraph::INode>> add_ext_forces_seq{};
133
134 for (auto var_force : solver_config.ext_force_config.ext_forces) {
135 if (EF_PointMass *ext_force = std::get_if<EF_PointMass>(&var_force.val)) {
136
137 auto central_mass = shamrock::solvergraph::IDataEdge<Tscal>::make_shared("", "");
138 auto central_pos = shamrock::solvergraph::IDataEdge<Tvec>::make_shared("", "");
139
141 set_central_mass([cmass = ext_force->central_mass](
143 central_mass.data = cmass;
144 });
145 set_central_mass.set_edges(central_mass);
146
148 set_central_pos([&](shamrock::solvergraph::IDataEdge<Tvec> &central_pos) {
149 central_pos.data = {}; // no support for offset yet
150 });
151 set_central_pos.set_edges(central_pos);
152
153 common::modules::AddForceCentralGravPotential<Tvec> add_force_central_grav_potential;
154 add_force_central_grav_potential.set_edges(
155 constant_G, central_mass, central_pos, field_xyz, sizes, field_axyz_ext);
156
157 add_ext_forces_seq.push_back(
158 std::make_shared<shamrock::solvergraph::OperationSequence>(
159 "Point mass",
160 std::vector<std::shared_ptr<shamrock::solvergraph::INode>>{
161 shambase::to_shared(std::move(set_central_pos)),
162 shambase::to_shared(std::move(set_central_mass)),
163 shambase::to_shared(std::move(add_force_central_grav_potential))}));
164
165 } else if (EF_PN_PW *ext_force = std::get_if<EF_PN_PW>(&var_force.val)) {
166
167 auto central_mass = shamrock::solvergraph::IDataEdge<Tscal>::make_shared("", "");
168 auto central_pos = shamrock::solvergraph::IDataEdge<Tvec>::make_shared("", "");
169
171 set_central_mass([cmass = ext_force->central_mass](
173 central_mass.data = cmass;
174 });
175 set_central_mass.set_edges(central_mass);
176
178 set_central_pos([cpos = ext_force->central_pos](
180 central_pos.data = cpos;
181 });
182 set_central_pos.set_edges(central_pos);
183
184 common::modules::AddForcePaczynskiWiita<Tvec> add_force_paczynski_wiita;
185 add_force_paczynski_wiita.set_edges(
186 constant_G,
187 constant_c,
188 central_mass,
189 central_pos,
190 field_xyz,
191 sizes,
192 field_axyz_ext);
193
194 add_ext_forces_seq.push_back(
195 std::make_shared<shamrock::solvergraph::OperationSequence>(
196 "Pseudo-Newtonian PW",
197 std::vector<std::shared_ptr<shamrock::solvergraph::INode>>{
198 shambase::to_shared(std::move(set_central_pos)),
199 shambase::to_shared(std::move(set_central_mass)),
200 shambase::to_shared(std::move(add_force_paczynski_wiita))}));
201
202 } else if (EF_LenseThirring *ext_force = std::get_if<EF_LenseThirring>(&var_force.val)) {
203
204 auto central_mass = shamrock::solvergraph::IDataEdge<Tscal>::make_shared("", "");
205 auto central_pos = shamrock::solvergraph::IDataEdge<Tvec>::make_shared("", "");
206
208 set_central_mass([cmass = ext_force->central_mass](
210 central_mass.data = cmass;
211 });
212 set_central_mass.set_edges(central_mass);
213
215 set_central_pos([&](shamrock::solvergraph::IDataEdge<Tvec> &central_pos) {
216 central_pos.data = {}; // no support for offset yet
217 });
218 set_central_pos.set_edges(central_pos);
219
220 common::modules::AddForceCentralGravPotential<Tvec> add_force_central_grav_potential;
221 add_force_central_grav_potential.set_edges(
222 constant_G, central_mass, central_pos, field_xyz, sizes, field_axyz_ext);
223
224 add_ext_forces_seq.push_back(
225 std::make_shared<shamrock::solvergraph::OperationSequence>(
226 "Point mass",
227 std::vector<std::shared_ptr<shamrock::solvergraph::INode>>{
228 shambase::to_shared(std::move(set_central_pos)),
229 shambase::to_shared(std::move(set_central_mass)),
230 shambase::to_shared(std::move(add_force_central_grav_potential))}));
231
232 } else if (
233 EF_ShearingBoxForce *ext_force = std::get_if<EF_ShearingBoxForce>(&var_force.val)) {
234
235 auto eta = shamrock::solvergraph::IDataEdge<Tscal>::make_shared("", "");
238 eta.data = ext_force->eta;
239 });
240 set_eta.set_edges(eta);
241
243 add_force_shearing_box_inertial_part{};
244 add_force_shearing_box_inertial_part.set_edges(eta, field_xyz, sizes, field_axyz_ext);
245
246 add_ext_forces_seq.push_back(
247 std::make_shared<shamrock::solvergraph::OperationSequence>(
248 "Shearing box force",
249 std::vector<std::shared_ptr<shamrock::solvergraph::INode>>{
250 shambase::to_shared(std::move(set_eta)),
251 shambase::to_shared(std::move(add_force_shearing_box_inertial_part))}));
252
253 } else if (
254 EF_VerticalDiscPotential *ext_force
255 = std::get_if<EF_VerticalDiscPotential>(&var_force.val)) {
256
257 auto central_mass = shamrock::solvergraph::IDataEdge<Tscal>::make_shared("", "");
258 auto R0 = shamrock::solvergraph::IDataEdge<Tscal>::make_shared("", "");
259
261 set_central_mass([cmass = ext_force->central_mass](
263 central_mass.data = cmass;
264 });
265 set_central_mass.set_edges(central_mass);
266
268 [r = ext_force->R0](shamrock::solvergraph::IDataEdge<Tscal> &R0) {
269 R0.data = r; // no support for offset yet
270 });
271 set_R0.set_edges(R0);
272
273 common::modules::AddForceVerticalDiscPotential<Tvec> add_force_vertical_disc_potential;
274 add_force_vertical_disc_potential.set_edges(
275 constant_G, central_mass, R0, field_xyz, sizes, field_axyz_ext);
276
277 add_ext_forces_seq.push_back(
278 std::make_shared<shamrock::solvergraph::OperationSequence>(
279 "Vertical disc potential",
280 std::vector<std::shared_ptr<shamrock::solvergraph::INode>>{
281 shambase::to_shared(std::move(set_R0)),
282 shambase::to_shared(std::move(set_central_mass)),
283 shambase::to_shared(std::move(add_force_vertical_disc_potential))}));
284
285 } else if (
286 EF_VelocityDissipation *ext_force
287 = std::get_if<EF_VelocityDissipation>(&var_force.val)) {
288
289 } else {
290 shambase::throw_unimplemented("this force is not handled, yet ...");
291 }
292 }
293
294 set_constant_G.evaluate();
295 set_constant_c.evaluate();
296
297 if (add_ext_forces_seq.size() > 0) {
299 "Add external forces", std::move(add_ext_forces_seq));
300 seq.evaluate();
301 }
302}
303
304template<class T>
305std::shared_ptr<shamrock::solvergraph::INode> register_constant_set(
306 shamrock::solvergraph::SolverGraph &solver_graph, std::string name, std::function<T()> getter) {
307 solver_graph.register_edge(name, shamrock::solvergraph::IDataEdge<T>("", ""));
308
309 solver_graph.register_node(
310 "set_" + name,
313 edge.data = getter();
314 }));
315
316 solver_graph
318 "set_" + name)
319 .set_edges(solver_graph.get_edge_ptr_base(name));
320
321 return solver_graph.get_node_ptr_base("set_" + name);
322}
323
324template<class Tvec, template<class> class SPHKernel>
326
327 StackEntry stack_loc{};
328
329 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
330
331 Tscal gpart_mass = solver_config.gpart_mass;
332
333 using namespace shamrock;
334 using namespace shamrock::patch;
335
336 PatchDataLayerLayout &pdl = scheduler().pdl_old();
337
338 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
339 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
340 const u32 iaxyz_ext = pdl.get_field_idx<Tvec>("axyz_ext");
341
342 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
343 sham::DeviceBuffer<Tvec> &buf_axyz = pdat.get_field_buf_ref<Tvec>(iaxyz);
344 sham::DeviceBuffer<Tvec> &buf_axyz_ext = pdat.get_field_buf_ref<Tvec>(iaxyz_ext);
345
346 sham::EventList depends_list;
347 auto axyz = buf_axyz.get_write_access(depends_list);
348 auto axyz_ext = buf_axyz_ext.get_read_access(depends_list);
349
350 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
351 shambase::parallel_for(
352 cgh, pdat.get_obj_cnt(), "add ext force acc to acc", [=](u64 gid) {
353 axyz[gid] += axyz_ext[gid];
354 });
355 });
356
357 buf_axyz.complete_event_state(e);
358 buf_axyz_ext.complete_event_state(e);
359 });
360
361 if (solver_config.ext_force_config.ext_forces.empty()) {
362 return; // skip if no external forces
363 }
364
365 using SolverConfigExtForce = typename Config::ExtForceConfig;
366 using EF_PointMass = typename SolverConfigExtForce::PointMass;
367 using EF_PN_PW = typename SolverConfigExtForce::PN_PW;
368 using EF_LenseThirring = typename SolverConfigExtForce::LenseThirring;
369
370 using namespace shamrock::solvergraph;
371 SolverGraph solver_graph{};
372
373 auto set_constant_G = register_constant_set<Tscal>(solver_graph, "constant_G", [&]() {
374 return solver_config.get_constant_G();
375 });
376 auto set_constant_c = register_constant_set<Tscal>(solver_graph, "constant_c", [&]() {
377 return solver_config.get_constant_c();
378 });
379
380 bool is_G_needed = false;
381 bool is_c_needed = false;
382
383 for (auto var_force : solver_config.ext_force_config.ext_forces) {
384 if (EF_PointMass *ext_force = std::get_if<EF_PointMass>(&var_force.val)) {
385
386 } else if (EF_PN_PW *ext_force = std::get_if<EF_PN_PW>(&var_force.val)) {
387 is_G_needed = true;
388 is_c_needed = true;
389 } else if (EF_LenseThirring *ext_force = std::get_if<EF_LenseThirring>(&var_force.val)) {
390 is_G_needed = true;
391 is_c_needed = true;
392 } else if (
393 EF_ShearingBoxForce *ext_force = std::get_if<EF_ShearingBoxForce>(&var_force.val)) {
394 } else if (
395 EF_VerticalDiscPotential *ext_force
396 = std::get_if<EF_VerticalDiscPotential>(&var_force.val)) {
397 } else if (
398 EF_VelocityDissipation *ext_force
399 = std::get_if<EF_VelocityDissipation>(&var_force.val)) {
400 } else {
401 shambase::throw_unimplemented("this force is not handled, yet ...");
402 }
403 }
404
405 std::vector<std::shared_ptr<shamrock::solvergraph::INode>> add_ext_forces_seq{};
406
407 if (is_G_needed) {
408 add_ext_forces_seq.push_back(set_constant_G);
409 }
410 if (is_c_needed) {
411 add_ext_forces_seq.push_back(set_constant_c);
412 }
413
414 auto field_xyz = solver_graph.register_edge("field_xyz", FieldRefs<Tvec>("", ""));
415 auto field_vxyz = solver_graph.register_edge("field_vxyz", FieldRefs<Tvec>("", ""));
416 auto field_axyz = solver_graph.register_edge("field_axyz", FieldRefs<Tvec>("", ""));
417 auto field_sizes = solver_graph.register_edge("field_sizes", Indexes<u32>("", ""));
418
419 auto set_field_xyz = solver_graph.register_node(
420 "set_field_xyz", NodeSetEdge<FieldRefs<Tvec>>([&](FieldRefs<Tvec> &field_xyz_edge) {
421 DDPatchDataFieldRef<Tvec> field_xyz_refs = {};
422 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
423 auto &field = pdat.get_field<Tvec>(0);
424 field_xyz_refs.add_obj(p.id_patch, std::ref(field));
425 });
426 field_xyz_edge.set_refs(field_xyz_refs);
427 }));
428 shambase::get_check_ref(set_field_xyz).set_edges(field_xyz);
429
430 auto set_field_vxyz = solver_graph.register_node(
431 "set_field_vxyz", NodeSetEdge<FieldRefs<Tvec>>([&](FieldRefs<Tvec> &field_vxyz_edge) {
432 DDPatchDataFieldRef<Tvec> field_vxyz_refs = {};
433 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
434 auto &field = pdat.get_field<Tvec>(ivxyz);
435 field_vxyz_refs.add_obj(p.id_patch, std::ref(field));
436 });
437 field_vxyz_edge.set_refs(field_vxyz_refs);
438 }));
439 shambase::get_check_ref(set_field_vxyz).set_edges(field_vxyz);
440
441 auto set_field_axyz = solver_graph.register_node(
442 "set_field_axyz", NodeSetEdge<FieldRefs<Tvec>>([&](FieldRefs<Tvec> &field_axyz_edge) {
443 DDPatchDataFieldRef<Tvec> field_axyz_refs = {};
444 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
445 auto &field = pdat.get_field<Tvec>(iaxyz);
446 field_axyz_refs.add_obj(p.id_patch, std::ref(field));
447 });
448 field_axyz_edge.set_refs(field_axyz_refs);
449 }));
450 shambase::get_check_ref(set_field_axyz).set_edges(field_axyz);
451
452 auto set_field_sizes = solver_graph.register_node(
453 "set_field_sizes", NodeSetEdge<Indexes<u32>>([&](Indexes<u32> &sizes) {
454 sizes.indexes = {};
455 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
456 sizes.indexes.add_obj(p.id_patch, pdat.get_obj_cnt());
457 });
458 }));
459 shambase::get_check_ref(set_field_sizes).set_edges(field_sizes);
460
461 add_ext_forces_seq.push_back(set_field_xyz);
462 add_ext_forces_seq.push_back(set_field_vxyz);
463 add_ext_forces_seq.push_back(set_field_axyz);
464 add_ext_forces_seq.push_back(set_field_sizes);
465
466 for (u32 i = 0; i < solver_config.ext_force_config.ext_forces.size(); i++) {
467
468 auto &var_force = solver_config.ext_force_config.ext_forces[i];
469
470 std::string prefix = shambase::format("ext_force_{}_", i);
471
472 if (EF_PointMass *ext_force = std::get_if<EF_PointMass>(&var_force.val)) {
473
474 } else if (EF_PN_PW *ext_force = std::get_if<EF_PN_PW>(&var_force.val)) {
475
476 } else if (EF_LenseThirring *ext_force = std::get_if<EF_LenseThirring>(&var_force.val)) {
477
478 std::string prefix_cmass = prefix + "cmass_";
479 std::string prefix_central_pos = prefix + "central_pos_";
480 std::string prefix_a_spin = prefix + "a_spin_";
481 std::string prefix_dir_spin = prefix + "dir_spin_";
482 std::string prefix_lt = prefix + "lt_";
483
484 auto set_cmass = register_constant_set<Tscal>(solver_graph, prefix_cmass, [&]() {
485 return ext_force->central_mass;
486 });
487
488 auto set_central_pos
489 = register_constant_set<Tvec>(solver_graph, prefix_central_pos, [&]() {
490 return Tvec{0, 0, 0}; // no support for offset yet
491 });
492
493 auto set_a_spin = register_constant_set<Tscal>(solver_graph, prefix_a_spin, [&]() {
494 return ext_force->a_spin;
495 });
496
497 auto set_dir_spin = register_constant_set<Tvec>(solver_graph, prefix_dir_spin, [&]() {
498 return ext_force->dir_spin;
499 });
500
501 auto add_force_lense_thirring = solver_graph.register_node(
503 shambase::get_check_ref(add_force_lense_thirring)
504 .set_edges(
505 solver_graph.get_edge_ptr<IDataEdge<Tscal>>("constant_G"),
506 solver_graph.get_edge_ptr<IDataEdge<Tscal>>("constant_c"),
507 solver_graph.get_edge_ptr<IDataEdge<Tscal>>(prefix_cmass),
508 solver_graph.get_edge_ptr<IDataEdge<Tvec>>(prefix_central_pos),
509 solver_graph.get_edge_ptr<IDataEdge<Tscal>>(prefix_a_spin),
510 solver_graph.get_edge_ptr<IDataEdge<Tvec>>(prefix_dir_spin),
511 solver_graph.get_edge_ptr<IFieldSpan<Tvec>>("field_xyz"),
512 solver_graph.get_edge_ptr<IFieldSpan<Tvec>>("field_vxyz"),
513 solver_graph.get_edge_ptr<Indexes<u32>>("field_sizes"),
514 solver_graph.get_edge_ptr<IFieldSpan<Tvec>>("field_axyz"));
515
516 add_ext_forces_seq.push_back(set_cmass);
517 add_ext_forces_seq.push_back(set_central_pos);
518 add_ext_forces_seq.push_back(set_a_spin);
519 add_ext_forces_seq.push_back(set_dir_spin);
520 add_ext_forces_seq.push_back(solver_graph.get_node_ptr_base(prefix_lt));
521
522 } else if (
523 EF_ShearingBoxForce *ext_force = std::get_if<EF_ShearingBoxForce>(&var_force.val)) {
524
525 std::string prefix_Omega_0 = prefix + "Omega_0_";
526 std::string prefix_q = prefix + "q_";
527 std::string prefix_shearing_box = prefix + "shearing_box_";
528
529 auto set_Omega_0 = register_constant_set<Tscal>(solver_graph, prefix_Omega_0, [&]() {
530 return ext_force->Omega_0;
531 });
532
533 auto set_q = register_constant_set<Tscal>(solver_graph, prefix_q, [&]() {
534 return ext_force->q;
535 });
536
537 auto add_force_shearing_box_non_inertial = solver_graph.register_node(
538 prefix_shearing_box,
540 shambase::get_check_ref(add_force_shearing_box_non_inertial)
541 .set_edges(
542 solver_graph.get_edge_ptr<IDataEdge<Tscal>>(prefix_Omega_0),
543 solver_graph.get_edge_ptr<IDataEdge<Tscal>>(prefix_q),
544 solver_graph.get_edge_ptr<IFieldSpan<Tvec>>("field_xyz"),
545 solver_graph.get_edge_ptr<IFieldSpan<Tvec>>("field_vxyz"),
546 solver_graph.get_edge_ptr<Indexes<u32>>("field_sizes"),
547 solver_graph.get_edge_ptr<IFieldSpan<Tvec>>("field_axyz"));
548
549 add_ext_forces_seq.push_back(set_Omega_0);
550 add_ext_forces_seq.push_back(set_q);
551 add_ext_forces_seq.push_back(solver_graph.get_node_ptr_base(prefix_shearing_box));
552
553 } else if (
554 EF_VerticalDiscPotential *ext_force
555 = std::get_if<EF_VerticalDiscPotential>(&var_force.val)) {
556 } else if (
557 EF_VelocityDissipation *ext_force
558 = std::get_if<EF_VelocityDissipation>(&var_force.val)) {
559 std::string prefix_eta = prefix + "eta_";
560 std::string prefix_velocity_dissipation = prefix + "velocity_dissipation_";
561
562 auto set_eta
563 = register_constant_set<Tscal>(solver_graph, prefix_eta, [eta = ext_force->eta]() {
564 return eta;
565 });
566
567 auto add_force_velocity_dissipation = solver_graph.register_node(
568 prefix_velocity_dissipation,
570 shambase::get_check_ref(add_force_velocity_dissipation)
571 .set_edges(
572 solver_graph.get_edge_ptr<IDataEdge<Tscal>>(prefix_eta),
573 solver_graph.get_edge_ptr<IFieldSpan<Tvec>>("field_vxyz"),
574 solver_graph.get_edge_ptr<Indexes<u32>>("field_sizes"),
575 solver_graph.get_edge_ptr<IFieldSpan<Tvec>>("field_axyz"));
576
577 add_ext_forces_seq.push_back(set_eta);
578 add_ext_forces_seq.push_back(
579 solver_graph.get_node_ptr_base(prefix_velocity_dissipation));
580
581 } else {
582 shambase::throw_unimplemented("this force is not handled, yet ...");
583 }
584 }
585
586 if (add_ext_forces_seq.size() > 0) {
587 OperationSequence seq("Add external forces", std::move(add_ext_forces_seq));
588 seq.evaluate();
589 }
590}
591
592template<class Tvec, template<class> class SPHKernel>
593void shammodels::sph::modules::ExternalForces<Tvec, SPHKernel>::point_mass_accrete_particles() {
594
595 StackEntry stack_loc{};
596
597 Tscal gpart_mass = solver_config.gpart_mass;
598
599 using namespace shamrock;
600 using namespace shamrock::patch;
601
602 using SolverConfigExtForce = typename Config::ExtForceConfig;
603 using EF_PointMass = typename SolverConfigExtForce::PointMass;
604 using EF_LenseThirring = typename SolverConfigExtForce::LenseThirring;
605
606 PatchDataLayerLayout &pdl = scheduler().pdl_old();
607 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
608 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
609
610 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
611
612 sham::DeviceQueue &q = shambase::get_check_ref(dev_sched).get_queue();
613
614 for (auto var_force : solver_config.ext_force_config.ext_forces) {
615
616 Tvec pos_accretion;
617 Tscal Racc;
618
619 if (EF_PointMass *ext_force = std::get_if<EF_PointMass>(&var_force.val)) {
620 pos_accretion = {0, 0, 0};
621 Racc = ext_force->Racc;
622 } else if (EF_PN_PW *ext_force = std::get_if<EF_PN_PW>(&var_force.val)) {
623 pos_accretion = {0, 0, 0};
624 Racc = ext_force->Racc;
625 } else if (EF_LenseThirring *ext_force = std::get_if<EF_LenseThirring>(&var_force.val)) {
626 pos_accretion = {0, 0, 0};
627 Racc = ext_force->Racc;
628 } else {
629 continue;
630 }
631
632 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
633 u32 Nobj = pdat.get_obj_cnt();
634
635 sham::DeviceBuffer<Tvec> &buf_xyz = pdat.get_field_buf_ref<Tvec>(ixyz);
636 sham::DeviceBuffer<Tvec> &buf_vxyz = pdat.get_field_buf_ref<Tvec>(ivxyz);
637
638 sycl::buffer<u32> not_accreted(Nobj);
639 sycl::buffer<u32> accreted(Nobj);
640
641 sham::EventList depends_list;
642 auto xyz = buf_xyz.get_read_access(depends_list);
643
644 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
645 sycl::accessor not_acc{not_accreted, cgh, sycl::write_only, sycl::no_init};
646 sycl::accessor acc{accreted, cgh, sycl::write_only, sycl::no_init};
647
648 Tvec r_sink = pos_accretion;
649 Tscal acc_rad2 = Racc * Racc;
650
651 shambase::parallel_for(cgh, Nobj, "check accretion", [=](i32 id_a) {
652 Tvec r = xyz[id_a] - r_sink;
653 bool not_accreted = sycl::dot(r, r) > acc_rad2;
654 not_acc[id_a] = (not_accreted) ? 1 : 0;
655 acc[id_a] = (!not_accreted) ? 1 : 0;
656 });
657 });
658
659 buf_xyz.complete_event_state(e);
660
661 std::tuple<std::optional<sycl::buffer<u32>>, u32> id_list_keep
662 = shamalgs::numeric::stream_compact(q.q, not_accreted, Nobj);
663
664 std::tuple<std::optional<sycl::buffer<u32>>, u32> id_list_accrete
665 = shamalgs::numeric::stream_compact(q.q, accreted, Nobj);
666
667 // sum accreted values onto sink
668
669 if (std::get<1>(id_list_accrete) > 0) {
670
671 u32 Naccrete = std::get<1>(id_list_accrete);
672
673 Tscal acc_mass = gpart_mass * Naccrete;
674
675 sham::DeviceBuffer<Tvec> pxyz_acc(Naccrete, dev_sched);
676
677 sham::EventList depends_list;
678
679 auto vxyz = buf_vxyz.get_read_access(depends_list);
680 auto accretion_p = pxyz_acc.get_write_access(depends_list);
681
682 auto e = q.submit(depends_list, [&, gpart_mass](sycl::handler &cgh) {
683 sycl::accessor id_acc{*std::get<0>(id_list_accrete), cgh, sycl::read_only};
684
685 shambase::parallel_for(
686 cgh, Naccrete, "compute sum momentum accretion", [=](i32 id_a) {
687 accretion_p[id_a] = gpart_mass * vxyz[id_acc[id_a]];
688 });
689 });
690
691 buf_vxyz.complete_event_state(e);
692 pxyz_acc.complete_event_state(e);
693
694 Tvec acc_pxyz = shamalgs::primitives::sum(dev_sched, pxyz_acc, 0, Naccrete);
695
696 logger::raw_ln("central potential accretion : += ", acc_mass);
697
698 pdat.keep_ids(*std::get<0>(id_list_keep), std::get<1>(id_list_keep));
699 }
700 });
701 }
702}
703
704using namespace shammath;
708
Adds the acceleration from a central gravitational potential (point mass).
Adds the Lense-Thirring force acceleration.
Adds the acceleration from a Paczyński–Wiita (1980) pseudo-Newtonian potential.
Adds the inertial part of the acceleration for a shearing box force.
Adds the non-inertial part of the acceleration for a shearing box force.
Adds the acceleration from a velocity dissipation force.
Adds the acceleration from a vertical disc potential.
constexpr const char * vxyz
3-velocity field
constexpr const char * xyz
Position field (3D coordinates).
shambase::DistributedData< PatchDataFieldRef< T > > DDPatchDataFieldRef
Alias for a DistributedData of PatchDataFieldRefs.
Node that applies a custom function to modify connected edges.
Declare a class to register and retrieve nodes and edges from a unique container.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
A buffer allocated in USM (Unified Shared Memory).
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
sycl::queue q
The SYCL queue associated with this context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
void add_ext_forces()
Add external forces to the particle acceleration. Note that forces dependent on velocity should be a...
void compute_ext_forces_indep_v()
Is run once per timestep; it computes the forces that are independent of velocity.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
Interface for a solver graph edge representing a field as spans.
void evaluate()
Evaluate the node.
Definition INode.hpp:109
A node that applies a custom function to modify connected edges.
void set_edges(std::shared_ptr< IEdge > to_set)
Set the edges of the node.
A graph container for managing solver nodes and edges with type-safe access.
std::shared_ptr< INode > & get_node_ptr_base(const std::string &name)
Retrieve a node by name as a shared pointer to the base interface.
std::shared_ptr< T > get_edge_ptr(const std::string &name)
Get a typed shared pointer to an edge by name.
std::shared_ptr< T > register_edge(const std::string &name, T &&edge)
Register an edge with automatic type deduction and shared pointer creation.
std::shared_ptr< IEdge > & get_edge_ptr_base(const std::string &name)
Retrieve an edge by name as a shared pointer to the base interface.
std::shared_ptr< T > register_node(const std::string &name, T &&node)
Register a node with automatic type deduction and shared pointer creation.
T & get_node_ref(const std::string &name)
Get a typed reference to a node by name.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
Definition numeric.cpp:84
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
namespace for basic c++ utilities
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
namespace for math utility
Definition AABB.hpp:26
namespace for the main framework
Definition __init__.py:1
void raw_ln(Types... var2)
Prints a log message with multiple arguments followed by a newline.
Definition logs.hpp:90
sph kernels
shambase::details::BasicStackEntry StackEntry
Alias for shambase::details::BasicStackEntry.
shammodels::ExtForceConfig< Tvec > ExtForceConfig
External force configuration.
Patch object that contain generic patch information.
Definition Patch.hpp:33