Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
SinkParticlesUpdate.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
25#include "shamcomm/logs.hpp"
28#include <shambackends/sycl.hpp>
29
30template<class Tvec, template<class> class SPHKernel>
32 StackEntry stack_loc{};
33
34 Tscal gpart_mass = solver_config.gpart_mass;
35
36 if (storage.sinks.is_empty()) {
37 return;
38 }
39
40 using namespace shamrock;
41 using namespace shamrock::patch;
42
43 PatchDataLayerLayout &pdl = scheduler().pdl_old();
44 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
45 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
46 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
47
48 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
49 sham::DeviceQueue &q = shambase::get_check_ref(dev_sched).get_queue();
50
51 std::vector<Sink> &sink_parts = storage.sinks.get();
52
53 u32 sink_id = 0;
54 bool had_accretion = false;
55 std::string log = "sink accretion :";
56
57 struct AccretionFlagBufs {
58 sham::DeviceBuffer<u32> not_accreted;
60 };
61
62 for (size_t sink_id = 0; sink_id < sink_parts.size(); sink_id++) {
63 Sink &s = sink_parts[sink_id];
64
65 Tvec r_sink = s.pos;
66 Tvec v_sink = s.velocity;
67 Tscal acc_rad2 = s.accretion_radius * s.accretion_radius;
68
69 // flags particles for accretion
71
72 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
73 u32 Nobj = pdat.get_obj_cnt();
74
75 sham::DeviceBuffer<Tvec> &buf_xyz = pdat.get_field_buf_ref<Tvec>(ixyz);
76 sham::DeviceBuffer<Tvec> &buf_vxyz = pdat.get_field_buf_ref<Tvec>(ivxyz);
77
78 sham::DeviceBuffer<u32> not_accreted(Nobj, dev_sched);
79 sham::DeviceBuffer<u32> accreted(Nobj, dev_sched);
80
82 q,
83 sham::MultiRef{buf_xyz},
84 sham::MultiRef{not_accreted, accreted},
85 Nobj,
86 [r_sink, acc_rad2](
87 u32 id_a,
88 const Tvec *__restrict xyz,
89 u32 *__restrict not_acc,
90 u32 *__restrict acc) {
91 Tvec r = xyz[id_a] - r_sink;
92 bool not_accreted = sycl::dot(r, r) > acc_rad2;
93 not_acc[id_a] = (not_accreted) ? 1 : 0;
94 acc[id_a] = (!not_accreted) ? 1 : 0;
95 });
96
97 accretion_flag_bufs.add_obj(
98 cur_p.id_patch, AccretionFlagBufs{std::move(not_accreted), std::move(accreted)});
99 });
100
101 // list the ids that will be accreted
103
104 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
105 u32 Nobj = pdat.get_obj_cnt();
106
107 sham::DeviceBuffer<u32> &accreted = accretion_flag_bufs.get(cur_p.id_patch).accreted;
108
109 sham::DeviceBuffer<u32> id_list_accrete
110 = shamalgs::stream_compact(dev_sched, accreted, Nobj);
111
112 bufs_id_list_accrete.add_obj(cur_p.id_patch, std::move(id_list_accrete));
113 });
114
115 // compute the accreted mass, position moment and linear momentum
116 Tscal s_acc_mass = 0;
117 Tvec s_acc_mxyz = {0, 0, 0};
118 Tvec s_acc_pxyz = {0, 0, 0};
119 Tvec s_acc_maxyz = {0, 0, 0};
120 Tvec s_acc_lxyz = {0, 0, 0};
121
122 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
123 u32 Nobj = pdat.get_obj_cnt();
124
125 sham::DeviceBuffer<Tvec> &buf_xyz = pdat.get_field_buf_ref<Tvec>(ixyz);
126 sham::DeviceBuffer<Tvec> &buf_vxyz = pdat.get_field_buf_ref<Tvec>(ivxyz);
127 sham::DeviceBuffer<Tvec> &buf_axyz = pdat.get_field_buf_ref<Tvec>(iaxyz);
128
129 sham::DeviceBuffer<u32> &id_list_accrete = bufs_id_list_accrete.get(cur_p.id_patch);
130
131 // sum accreted values onto sink
132 if (id_list_accrete.get_size() > 0) {
133 u32 Naccrete = shambase::narrow_or_throw<u32>(id_list_accrete.get_size());
134
135 Tscal acc_mass = gpart_mass * Naccrete;
136
137 sham::DeviceBuffer<Tvec> pxyz_acc(Naccrete, dev_sched);
138 sham::DeviceBuffer<Tvec> maxyz_acc(Naccrete, dev_sched);
139 sham::DeviceBuffer<Tvec> mxyz_acc(Naccrete, dev_sched);
140 sham::DeviceBuffer<Tvec> lxyz_acc(Naccrete, dev_sched);
141
143 q,
144 sham::MultiRef{buf_xyz, buf_vxyz, buf_axyz, id_list_accrete},
145 sham::MultiRef{pxyz_acc, mxyz_acc, maxyz_acc, lxyz_acc},
146 Naccrete,
147 [gpart_mass, r_sink, v_sink, dt](
148 u32 id_a,
149 const Tvec *__restrict xyz,
150 const Tvec *__restrict vxyz,
151 const Tvec *__restrict axyz,
152 const u32 *__restrict id_acc,
153 Tvec *__restrict accretion_p,
154 Tvec *__restrict accretion_mr,
155 Tvec *__restrict accretion_ma,
156 Tvec *__restrict accretion_l) {
157 u32 i_a = id_acc[id_a];
158 Tvec r = xyz[i_a];
159 Tvec v = vxyz[i_a];
160 Tvec a = axyz[i_a];
161 accretion_p[id_a] = gpart_mass * v;
162 accretion_mr[id_a] = gpart_mass * r;
163 accretion_ma[id_a] = gpart_mass * a;
164
165 // dirty trick to account for the residual acceleration in the spin. This
166 // allows us to maintain a much better angular momentum conservation.
167 v += a * dt / 2;
168 accretion_l[id_a] = gpart_mass * sycl::cross(r - r_sink, v - v_sink);
169 });
170
171 Tvec acc_pxyz = shamalgs::primitives::sum(dev_sched, pxyz_acc, 0, Naccrete);
172 Tvec acc_mxyz = shamalgs::primitives::sum(dev_sched, mxyz_acc, 0, Naccrete);
173 Tvec acc_maxyz = shamalgs::primitives::sum(dev_sched, maxyz_acc, 0, Naccrete);
174 Tvec acc_lxyz = shamalgs::primitives::sum(dev_sched, lxyz_acc, 0, Naccrete);
175
176 s_acc_mass += acc_mass;
177 s_acc_pxyz += acc_pxyz;
178 s_acc_mxyz += acc_mxyz;
179 s_acc_maxyz += acc_maxyz;
180 s_acc_lxyz += acc_lxyz;
181 }
182 });
183
184 Tscal sum_acc_mass = shamalgs::collective::allreduce_sum(s_acc_mass);
185
186 // if there is accretion continue otherwise skip that part
187 if (sum_acc_mass <= 0) {
188 continue;
189 }
190
191 Tvec sum_acc_pxyz = shamalgs::collective::allreduce_sum(s_acc_pxyz);
192 Tvec sum_acc_mxyz = shamalgs::collective::allreduce_sum(s_acc_mxyz);
193 Tvec sum_acc_maxyz = shamalgs::collective::allreduce_sum(s_acc_maxyz);
194 Tvec sum_acc_lxyz = shamalgs::collective::allreduce_sum(s_acc_lxyz);
195
196 // compute the new sink values
197 Tscal new_mass = s.mass + sum_acc_mass;
198 Tvec new_pos = (sum_acc_mxyz + s.pos * s.mass) / (s.mass + sum_acc_mass);
199 Tvec new_vel = (sum_acc_pxyz + s.velocity * s.mass) / (s.mass + sum_acc_mass);
200 Tvec new_acc = (sum_acc_maxyz + s.sph_acceleration * s.mass) / (s.mass + sum_acc_mass);
201 Tvec new_ang_mom = s.angular_momentum + sum_acc_lxyz
202 - new_mass * sycl::cross(new_pos - s.pos, new_vel - s.velocity);
203
204 // write back the updated sink state
205 auto new_state = s;
206 new_state.mass = new_mass;
207 new_state.pos = new_pos;
208 new_state.velocity = new_vel;
209 new_state.angular_momentum = new_ang_mom;
210 new_state.sph_acceleration = new_acc;
211
212 had_accretion = true;
213 log += shambase::format(
214 "\n id {} deltas : mass={} r={} v={} l={}",
215 sink_id,
216 new_state.mass - s.mass,
217 new_state.pos - s.pos,
218 new_state.velocity - s.velocity,
219 new_state.angular_momentum - s.angular_momentum);
220
221 s = new_state;
222
223 // evict accreted particles from patches
224 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
225 u32 Nobj = pdat.get_obj_cnt();
226
227 sham::DeviceBuffer<u32> &not_accreted
228 = accretion_flag_bufs.get(cur_p.id_patch).not_accreted;
229 sham::DeviceBuffer<u32> &accreted = accretion_flag_bufs.get(cur_p.id_patch).accreted;
230
231 sham::DeviceBuffer<u32> &id_list_accrete = bufs_id_list_accrete.get(cur_p.id_patch);
232
233 if (id_list_accrete.get_size() > 0) {
234
235 sham::DeviceBuffer<u32> id_list_keep
236 = shamalgs::stream_compact(dev_sched, not_accreted, Nobj);
237
238 pdat.keep_ids(
239 id_list_keep, shambase::narrow_or_throw<u32>(id_list_keep.get_size()));
240 }
241 });
242 }
243
244 if (shamcomm::world_rank() == 0 && had_accretion) {
245 logger::info_ln("sph::Sink", log);
246 }
247}
248
249template<class Tvec, template<class> class SPHKernel>
251
252 StackEntry stack_loc{};
253
254 if (storage.sinks.is_empty()) {
255 return;
256 }
257
258 compute_ext_forces();
259
260 std::vector<Sink> &sink_parts = storage.sinks.get();
261
262 for (Sink &s : sink_parts) {
263 s.velocity += (dt / 2) * s.sph_acceleration;
264 }
265
266 for (Sink &s : sink_parts) {
267 s.velocity += (dt / 2) * s.ext_acceleration;
268 }
269
270 for (Sink &s : sink_parts) {
271 s.pos += (dt) *s.velocity;
272 }
273
274 for (Sink &s : sink_parts) {
275 s.velocity += (dt / 2) * s.ext_acceleration;
276 }
277}
278
279template<class Tvec, template<class> class SPHKernel>
281
282 StackEntry stack_loc{};
283
284 if (storage.sinks.is_empty()) {
285 return;
286 }
287
288 std::vector<Sink> &sink_parts = storage.sinks.get();
289
290 for (Sink &s : sink_parts) {
291 s.velocity += (dt / 2) * s.sph_acceleration;
292 }
293}
294
295template<class Tvec, template<class> class SPHKernel>
297
298 StackEntry stack_loc{};
299
300 Tscal gpart_mass = solver_config.gpart_mass;
301
302 if (storage.sinks.is_empty()) {
303 return;
304 }
305
306 std::vector<Sink> &sink_parts = storage.sinks.get();
307
308 Tscal G = solver_config.get_constant_G();
309 Tscal epsilon_grav = 1e-9;
310
311 using namespace shamrock;
312 using namespace shamrock::patch;
313
314 PatchDataLayerLayout &pdl = scheduler().pdl_old();
315 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
316 const u32 iaxyz_ext = pdl.get_field_idx<Tvec>("axyz_ext");
317
318 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
319 sham::DeviceQueue &q = shambase::get_check_ref(dev_sched).get_queue();
320
321 std::vector<Tvec> result_acc_sinks{};
322
323 for (Sink &s : sink_parts) {
324
325 Tvec sph_acc_sink = {};
326
327 scheduler().for_each_patchdata_nonempty(
328 [&, G, epsilon_grav, gpart_mass](Patch cur_p, PatchDataLayer &pdat) {
329 sham::DeviceBuffer<Tvec> &buf_xyz = pdat.get_field_buf_ref<Tvec>(ixyz);
330 sham::DeviceBuffer<Tvec> &buf_axyz_ext = pdat.get_field_buf_ref<Tvec>(iaxyz_ext);
331
332 sham::DeviceBuffer<Tvec> buf_sync_axyz(pdat.get_obj_cnt(), dev_sched);
333
334 Tscal sink_mass = s.mass;
335 Tscal sink_racc = s.accretion_radius;
336 Tvec sink_pos = s.pos;
337
338 sham::EventList depends_list;
339 auto xyz = buf_xyz.get_read_access(depends_list);
340 auto axyz_ext = buf_axyz_ext.get_write_access(depends_list);
341 auto axyz_sync = buf_sync_axyz.get_write_access(depends_list);
342
343 auto e = q.submit(
344 depends_list,
345 [&, G, epsilon_grav, sink_mass, sink_pos, sink_racc](sycl::handler &cgh) {
346 shambase::parallel_for(
347 cgh, pdat.get_obj_cnt(), "sink-sph forces", [=](i32 id_a) {
348 Tvec r_a = xyz[id_a];
349
350 Tvec delta = r_a - sink_pos;
351 Tscal d = sycl::length(delta);
352
353 Tvec force = G * delta / (d * d * d);
354
355 // This is a hack to avoid the sink kaboom effect
356 // when the particle is being advected close to the sink before
357 // being accreted
358 if (d < sink_racc) {
359 force = {0, 0, 0};
360 }
361
362 axyz_sync[id_a] = force * gpart_mass;
363 axyz_ext[id_a] += -force * sink_mass;
364 });
365 });
366
367 buf_xyz.complete_event_state(e);
368 buf_axyz_ext.complete_event_state(e);
369 buf_sync_axyz.complete_event_state(e);
370
371 sph_acc_sink
372 += shamalgs::primitives::sum(dev_sched, buf_sync_axyz, 0, pdat.get_obj_cnt());
373 });
374
375 result_acc_sinks.push_back(sph_acc_sink);
376 }
377
378 std::vector<Tvec> gathered_result_acc_sinks{};
380 result_acc_sinks, gathered_result_acc_sinks, MPI_COMM_WORLD);
381
382 u32 id_s = 0;
383 for (Sink &s : sink_parts) {
384
385 s.sph_acceleration = {};
386
387 for (u32 rid = 0; rid < shamcomm::world_size(); rid++) {
388 s.sph_acceleration += gathered_result_acc_sinks[rid * sink_parts.size() + id_s];
389 }
390
391 id_s++;
392 }
393}
394
395template<class Tvec, template<class> class SPHKernel>
397
398 StackEntry stack_loc{};
399
400 if (storage.sinks.is_empty()) {
401 return;
402 }
403
404 std::vector<Sink> &sink_parts = storage.sinks.get();
405
406 for (Sink &s : sink_parts) {
407 s.ext_acceleration = Tvec{};
408 }
409
410 Tscal G = solver_config.get_constant_G();
411 Tscal epsilon_grav_sink = 1e-9;
412
413 for (Sink &s1 : sink_parts) {
414 Tvec sum{};
415 for (Sink &s2 : sink_parts) {
416 Tvec rij = s1.pos - s2.pos;
417 Tscal rij_scal = sycl::length(rij);
418 sum -= G * s2.mass * rij / (rij_scal * rij_scal * rij_scal + epsilon_grav_sink);
419 }
420 s1.ext_acceleration = sum;
421 }
422}
423
424using namespace shammath;
428
constexpr const char * axyz
3-acceleration field
constexpr const char * vxyz
3-velocity field
constexpr const char * xyz
Position field (3D coordinates)
std::uint32_t u32
32 bit unsigned integer
std::int32_t i32
32 bit integer
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
size_t get_size() const
Gets the number of elements in the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
Represents a collection of objects distributed across patches identified by a u64 id.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on vector with size query (size querying variant of vector_allgatherv_ks) //TODO add fault...
Definition exchanges.hpp:98
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
Definition numeric.cpp:84
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error...
Definition memory.hpp:110
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
i32 world_size()
Gives the size of the MPI communicator.
Definition worldInfo.cpp:38
namespace for math utility
Definition AABB.hpp:26
namespace for the main framework
Definition __init__.py:1
Utilities for safe type narrowing conversions.
sph kernels
A class that references multiple buffers or similar objects.
Patch object that contain generic patch information.
Definition Patch.hpp:33
u64 id_patch
unique key that identify the patch
Definition Patch.hpp:86