Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
SinkParticlesUpdate.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
17
25#include "shamcomm/logs.hpp"
28#include <shambackends/sycl.hpp>
29
30template<class Tvec, template<class> class SPHKernel>
31void shammodels::sph::modules::SinkParticlesUpdate<Tvec, SPHKernel>::accrete_particles(Tscal dt) {
32 StackEntry stack_loc{};
33
34 Tscal gpart_mass = solver_config.gpart_mass;
35
36 if (storage.sinks.is_empty()) {
37 return;
38 }
39
40 using namespace shamrock;
41 using namespace shamrock::patch;
42
43 PatchDataLayerLayout &pdl = scheduler().pdl_old();
44 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
45 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
46 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
47
48 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
49 sham::DeviceQueue &q = shambase::get_check_ref(dev_sched).get_queue();
50
51 std::vector<Sink> &sink_parts = storage.sinks.get();
52
53 u32 sink_id = 0;
54 bool had_accretion = false;
55 std::string log = "sink accretion :";
56
57 struct AccretionFlagBufs {
58 sham::DeviceBuffer<u32> not_accreted;
59 sham::DeviceBuffer<u32> accreted;
60 };
61
62 for (size_t sink_id = 0; sink_id < sink_parts.size(); sink_id++) {
63 Sink &s = sink_parts[sink_id];
64
65 Tvec r_sink = s.pos;
66 Tvec v_sink = s.velocity;
67 Tscal acc_rad2 = s.accretion_radius * s.accretion_radius;
68
69 // flags particles for accretion
71
72 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
73 u32 Nobj = pdat.get_obj_cnt();
74
75 sham::DeviceBuffer<Tvec> &buf_xyz = pdat.get_field_buf_ref<Tvec>(ixyz);
76 sham::DeviceBuffer<Tvec> &buf_vxyz = pdat.get_field_buf_ref<Tvec>(ivxyz);
77
78 sham::DeviceBuffer<u32> not_accreted(Nobj, dev_sched);
79 sham::DeviceBuffer<u32> accreted(Nobj, dev_sched);
80
82 q,
83 sham::MultiRef{buf_xyz},
84 sham::MultiRef{not_accreted, accreted},
85 Nobj,
86 [r_sink, acc_rad2](
87 u32 id_a,
88 const Tvec *__restrict xyz,
89 u32 *__restrict not_acc,
90 u32 *__restrict acc) {
91 Tvec r = xyz[id_a] - r_sink;
92 bool not_accreted = sycl::dot(r, r) > acc_rad2;
93 not_acc[id_a] = (not_accreted) ? 1 : 0;
94 acc[id_a] = (!not_accreted) ? 1 : 0;
95 });
96
97 accretion_flag_bufs.add_obj(
98 cur_p.id_patch, AccretionFlagBufs{std::move(not_accreted), std::move(accreted)});
99 });
100
101 // list the ids that will be accreted
103
104 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
105 u32 Nobj = pdat.get_obj_cnt();
106
107 sham::DeviceBuffer<u32> &accreted = accretion_flag_bufs.get(cur_p.id_patch).accreted;
108
109 sham::DeviceBuffer<u32> id_list_accrete
110 = shamalgs::stream_compact(dev_sched, accreted, Nobj);
111
112 bufs_id_list_accrete.add_obj(cur_p.id_patch, std::move(id_list_accrete));
113 });
114
115 // compute the accreted mass, position moment and linear momentum
116 Tscal s_acc_mass = 0;
117 Tvec s_acc_mxyz = {0, 0, 0};
118 Tvec s_acc_pxyz = {0, 0, 0};
119 Tvec s_acc_maxyz = {0, 0, 0};
120 Tvec s_acc_lxyz = {0, 0, 0};
121
122 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
123 u32 Nobj = pdat.get_obj_cnt();
124
125 sham::DeviceBuffer<Tvec> &buf_xyz = pdat.get_field_buf_ref<Tvec>(ixyz);
126 sham::DeviceBuffer<Tvec> &buf_vxyz = pdat.get_field_buf_ref<Tvec>(ivxyz);
127 sham::DeviceBuffer<Tvec> &buf_axyz = pdat.get_field_buf_ref<Tvec>(iaxyz);
128
129 sham::DeviceBuffer<u32> &id_list_accrete = bufs_id_list_accrete.get(cur_p.id_patch);
130
131 // sum accreted values onto sink
132 if (id_list_accrete.get_size() > 0) {
133 u32 Naccrete = shambase::narrow_or_throw<u32>(id_list_accrete.get_size());
134
135 Tscal acc_mass = gpart_mass * Naccrete;
136
137 sham::DeviceBuffer<Tvec> pxyz_acc(Naccrete, dev_sched);
138 sham::DeviceBuffer<Tvec> maxyz_acc(Naccrete, dev_sched);
139 sham::DeviceBuffer<Tvec> mxyz_acc(Naccrete, dev_sched);
140 sham::DeviceBuffer<Tvec> lxyz_acc(Naccrete, dev_sched);
141
143 q,
144 sham::MultiRef{buf_xyz, buf_vxyz, buf_axyz, id_list_accrete},
145 sham::MultiRef{pxyz_acc, mxyz_acc, maxyz_acc, lxyz_acc},
146 Naccrete,
147 [gpart_mass, r_sink, v_sink, dt](
148 u32 id_a,
149 const Tvec *__restrict xyz,
150 const Tvec *__restrict vxyz,
151 const Tvec *__restrict axyz,
152 const u32 *__restrict id_acc,
153 Tvec *__restrict accretion_p,
154 Tvec *__restrict accretion_mr,
155 Tvec *__restrict accretion_ma,
156 Tvec *__restrict accretion_l) {
157 u32 i_a = id_acc[id_a];
158 Tvec r = xyz[i_a];
159 Tvec v = vxyz[i_a];
160 Tvec a = axyz[i_a];
161 accretion_p[id_a] = gpart_mass * v;
162 accretion_mr[id_a] = gpart_mass * r;
163 accretion_ma[id_a] = gpart_mass * a;
164
165 // dirty trick to account for the residual acceleration in the spin. This
166 // allows us to maitain a much better angular momentum conservation.
167 v += a * dt / 2;
168 accretion_l[id_a] = gpart_mass * sycl::cross(r - r_sink, v - v_sink);
169 });
170
171 Tvec acc_pxyz = shamalgs::primitives::sum(dev_sched, pxyz_acc, 0, Naccrete);
172 Tvec acc_mxyz = shamalgs::primitives::sum(dev_sched, mxyz_acc, 0, Naccrete);
173 Tvec acc_maxyz = shamalgs::primitives::sum(dev_sched, maxyz_acc, 0, Naccrete);
174 Tvec acc_lxyz = shamalgs::primitives::sum(dev_sched, lxyz_acc, 0, Naccrete);
175
176 s_acc_mass += acc_mass;
177 s_acc_pxyz += acc_pxyz;
178 s_acc_mxyz += acc_mxyz;
179 s_acc_maxyz += acc_maxyz;
180 s_acc_lxyz += acc_lxyz;
181 }
182 });
183
184 Tscal sum_acc_mass = shamalgs::collective::allreduce_sum(s_acc_mass);
185
186 // if there is accretion continue otherwise skip that part
187 if (sum_acc_mass <= 0) {
188 continue;
189 }
190
191 Tvec sum_acc_pxyz = shamalgs::collective::allreduce_sum(s_acc_pxyz);
192 Tvec sum_acc_mxyz = shamalgs::collective::allreduce_sum(s_acc_mxyz);
193 Tvec sum_acc_maxyz = shamalgs::collective::allreduce_sum(s_acc_maxyz);
194 Tvec sum_acc_lxyz = shamalgs::collective::allreduce_sum(s_acc_lxyz);
195
196 // compute the new sink values
197 Tscal new_mass = s.mass + sum_acc_mass;
198 Tvec new_pos = (sum_acc_mxyz + s.pos * s.mass) / (s.mass + sum_acc_mass);
199 Tvec new_vel = (sum_acc_pxyz + s.velocity * s.mass) / (s.mass + sum_acc_mass);
200 Tvec new_acc = (sum_acc_maxyz + s.sph_acceleration * s.mass) / (s.mass + sum_acc_mass);
201 Tvec new_ang_mom = s.angular_momentum + sum_acc_lxyz
202 - new_mass * sycl::cross(new_pos - s.pos, new_vel - s.velocity);
203
204 // write back the updated sink state
205 auto new_state = s;
206 new_state.mass = new_mass;
207 new_state.pos = new_pos;
208 new_state.velocity = new_vel;
209 new_state.angular_momentum = new_ang_mom;
210 new_state.sph_acceleration = new_acc;
211
212 had_accretion = true;
213 log += shambase::format(
214 "\n id {} deltas : mass={} r={} v={} l={}",
215 sink_id,
216 new_state.mass - s.mass,
217 new_state.pos - s.pos,
218 new_state.velocity - s.velocity,
219 new_state.angular_momentum - s.angular_momentum);
220
221 s = new_state;
222
223 // evict accreted particles from patches
224 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
225 u32 Nobj = pdat.get_obj_cnt();
226
227 sham::DeviceBuffer<u32> &not_accreted
228 = accretion_flag_bufs.get(cur_p.id_patch).not_accreted;
229 sham::DeviceBuffer<u32> &accreted = accretion_flag_bufs.get(cur_p.id_patch).accreted;
230
231 sham::DeviceBuffer<u32> &id_list_accrete = bufs_id_list_accrete.get(cur_p.id_patch);
232
233 if (id_list_accrete.get_size() > 0) {
234
235 sham::DeviceBuffer<u32> id_list_keep
236 = shamalgs::stream_compact(dev_sched, not_accreted, Nobj);
237
238 pdat.keep_ids(
239 id_list_keep, shambase::narrow_or_throw<u32>(id_list_keep.get_size()));
240 }
241 });
242 }
243
244 if (shamcomm::world_rank() == 0 && had_accretion) {
245 logger::info_ln("sph::Sink", log);
246 }
247}
248
249template<class Tvec, template<class> class SPHKernel>
250void shammodels::sph::modules::SinkParticlesUpdate<Tvec, SPHKernel>::predictor_step(Tscal dt) {
251
252 StackEntry stack_loc{};
253
254 if (storage.sinks.is_empty()) {
255 return;
256 }
257
258 compute_ext_forces();
259
260 std::vector<Sink> &sink_parts = storage.sinks.get();
261
262 for (Sink &s : sink_parts) {
263 s.velocity += (dt / 2) * (s.sph_acceleration + s.ext_acceleration);
264 }
265
266 for (Sink &s : sink_parts) {
267 s.pos += (dt) *s.velocity;
268 }
269}
270
271template<class Tvec, template<class> class SPHKernel>
272void shammodels::sph::modules::SinkParticlesUpdate<Tvec, SPHKernel>::corrector_step(Tscal dt) {
273
274 StackEntry stack_loc{};
275
276 if (storage.sinks.is_empty()) {
277 return;
278 }
279
280 std::vector<Sink> &sink_parts = storage.sinks.get();
281
282 for (Sink &s : sink_parts) {
283 s.velocity += (dt / 2) * (s.sph_acceleration + s.ext_acceleration);
284 }
285}
286
287template<class Tvec, template<class> class SPHKernel>
288void shammodels::sph::modules::SinkParticlesUpdate<Tvec, SPHKernel>::compute_sph_forces() {
289
290 StackEntry stack_loc{};
291
292 Tscal gpart_mass = solver_config.gpart_mass;
293
294 if (storage.sinks.is_empty()) {
295 return;
296 }
297
298 std::vector<Sink> &sink_parts = storage.sinks.get();
299
300 Tscal G = solver_config.get_constant_G();
301 Tscal epsilon_grav = 1e-9;
302
303 using namespace shamrock;
304 using namespace shamrock::patch;
305
306 PatchDataLayerLayout &pdl = scheduler().pdl_old();
307 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
308 const u32 iaxyz_ext = pdl.get_field_idx<Tvec>("axyz_ext");
309
310 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
311 sham::DeviceQueue &q = shambase::get_check_ref(dev_sched).get_queue();
312
313 std::vector<Tvec> result_acc_sinks{};
314
315 for (Sink &s : sink_parts) {
316
317 Tvec sph_acc_sink = {};
318
319 scheduler().for_each_patchdata_nonempty(
320 [&, G, epsilon_grav, gpart_mass](Patch cur_p, PatchDataLayer &pdat) {
321 sham::DeviceBuffer<Tvec> &buf_xyz = pdat.get_field_buf_ref<Tvec>(ixyz);
322 sham::DeviceBuffer<Tvec> &buf_axyz_ext = pdat.get_field_buf_ref<Tvec>(iaxyz_ext);
323
324 sham::DeviceBuffer<Tvec> buf_sync_axyz(pdat.get_obj_cnt(), dev_sched);
325
326 Tscal sink_mass = s.mass;
327 Tscal sink_racc = s.accretion_radius;
328 Tvec sink_pos = s.pos;
329
330 sham::EventList depends_list;
331 auto xyz = buf_xyz.get_read_access(depends_list);
332 auto axyz_ext = buf_axyz_ext.get_write_access(depends_list);
333 auto axyz_sync = buf_sync_axyz.get_write_access(depends_list);
334
335 auto e = q.submit(
336 depends_list,
337 [&, G, epsilon_grav, sink_mass, sink_pos, sink_racc](sycl::handler &cgh) {
338 shambase::parallel_for(
339 cgh, pdat.get_obj_cnt(), "sink-sph forces", [=](i32 id_a) {
340 Tvec r_a = xyz[id_a];
341
342 Tvec delta = r_a - sink_pos;
343 Tscal d = sycl::length(delta);
344
345 Tvec force = G * delta / (d * d * d);
346
347 // This is a hack to avoid the sink kaboom effect
348 // when the particle is being advected close to the sink before
349 // being accreted
350 if (d < sink_racc) {
351 force = {0, 0, 0};
352 }
353
354 axyz_sync[id_a] = force * gpart_mass;
355 axyz_ext[id_a] += -force * sink_mass;
356 });
357 });
358
359 buf_xyz.complete_event_state(e);
360 buf_axyz_ext.complete_event_state(e);
361 buf_sync_axyz.complete_event_state(e);
362
363 sph_acc_sink
364 += shamalgs::primitives::sum(dev_sched, buf_sync_axyz, 0, pdat.get_obj_cnt());
365 });
366
367 result_acc_sinks.push_back(sph_acc_sink);
368 }
369
370 std::vector<Tvec> gathered_result_acc_sinks{};
372 result_acc_sinks, gathered_result_acc_sinks, MPI_COMM_WORLD);
373
374 u32 id_s = 0;
375 for (Sink &s : sink_parts) {
376
377 s.sph_acceleration = {};
378
379 for (u32 rid = 0; rid < shamcomm::world_size(); rid++) {
380 s.sph_acceleration += gathered_result_acc_sinks[rid * sink_parts.size() + id_s];
381 }
382
383 id_s++;
384 }
385}
386
387template<class Tvec, template<class> class SPHKernel>
388void shammodels::sph::modules::SinkParticlesUpdate<Tvec, SPHKernel>::compute_ext_forces() {
389
390 StackEntry stack_loc{};
391
392 if (storage.sinks.is_empty()) {
393 return;
394 }
395
396 std::vector<Sink> &sink_parts = storage.sinks.get();
397
398 for (Sink &s : sink_parts) {
399 s.ext_acceleration = Tvec{};
400 }
401
402 Tscal G = solver_config.get_constant_G();
403 Tscal epsilon_grav_sink = 1e-9;
404
405 for (Sink &s1 : sink_parts) {
406 Tvec sum{};
407 for (Sink &s2 : sink_parts) {
408 Tvec rij = s1.pos - s2.pos;
409 Tscal rij_scal = sycl::length(rij);
410 sum -= G * s2.mass * rij / (rij_scal * rij_scal * rij_scal + epsilon_grav_sink);
411 }
412 s1.ext_acceleration = sum;
413 }
414}
415
416using namespace shammath;
420
constexpr const char * axyz
3-acceleration field
constexpr const char * vxyz
3-velocity field
constexpr const char * xyz
Position field (3D coordinates).
std::uint32_t u32
32 bit unsigned integer
std::int32_t i32
32 bit integer
A buffer allocated in USM (Unified Shared Memory).
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
size_t get_size() const
Gets the number of elements in the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
Represents a collection of objects distributed across patches identified by a u64 id.
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
T & get(u64 id)
Returns a reference to an object in the collection.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on vector with size query (size querying variant of vector_allgatherv_ks) //TODO add fault...
Definition exchanges.hpp:98
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
Definition numeric.cpp:84
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
i32 world_size()
Gives the size of the MPI communicator.
Definition worldInfo.cpp:38
namespace for math utility
Definition AABB.hpp:26
namespace for the main framework
Definition __init__.py:1
Utilities for safe type narrowing conversions.
void info_ln(std::string module_name, Types... var2)
Prints a log message with multiple arguments followed by a newline.
Definition logs.hpp:133
sph kernels
shambase::details::BasicStackEntry StackEntry
Alias for shambase::details::BasicStackEntry.
A class that references multiple buffers or similar objects.
Patch object that contains generic patch information.
Definition Patch.hpp:33
u64 id_patch
unique key that identifies the patch
Definition Patch.hpp:86