30template<
class Tvec,
template<
class>
class SPHKernel>
34 Tscal gpart_mass = solver_config.gpart_mass;
36 if (storage.sinks.is_empty()) {
41 using namespace shamrock::patch;
48 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
51 std::vector<Sink> &sink_parts = storage.sinks.get();
54 bool had_accretion =
false;
55 std::string log =
"sink accretion :";
57 struct AccretionFlagBufs {
62 for (
size_t sink_id = 0; sink_id < sink_parts.size(); sink_id++) {
63 Sink &s = sink_parts[sink_id];
66 Tvec v_sink = s.velocity;
67 Tscal acc_rad2 = s.accretion_radius * s.accretion_radius;
73 u32 Nobj = pdat.get_obj_cnt();
88 const Tvec *__restrict
xyz,
89 u32 *__restrict not_acc,
90 u32 *__restrict acc) {
91 Tvec r =
xyz[id_a] - r_sink;
92 bool not_accreted = sycl::dot(r, r) > acc_rad2;
93 not_acc[id_a] = (not_accreted) ? 1 : 0;
94 acc[id_a] = (!not_accreted) ? 1 : 0;
97 accretion_flag_bufs.add_obj(
98 cur_p.
id_patch, AccretionFlagBufs{std::move(not_accreted), std::move(accreted)});
105 u32 Nobj = pdat.get_obj_cnt();
112 bufs_id_list_accrete.add_obj(cur_p.
id_patch, std::move(id_list_accrete));
116 Tscal s_acc_mass = 0;
117 Tvec s_acc_mxyz = {0, 0, 0};
118 Tvec s_acc_pxyz = {0, 0, 0};
119 Tvec s_acc_maxyz = {0, 0, 0};
120 Tvec s_acc_lxyz = {0, 0, 0};
123 u32 Nobj = pdat.get_obj_cnt();
132 if (id_list_accrete.get_size() > 0) {
135 Tscal acc_mass = gpart_mass * Naccrete;
147 [gpart_mass, r_sink, v_sink, dt](
149 const Tvec *__restrict
xyz,
150 const Tvec *__restrict
vxyz,
151 const Tvec *__restrict
axyz,
152 const u32 *__restrict id_acc,
153 Tvec *__restrict accretion_p,
154 Tvec *__restrict accretion_mr,
155 Tvec *__restrict accretion_ma,
156 Tvec *__restrict accretion_l) {
157 u32 i_a = id_acc[id_a];
161 accretion_p[id_a] = gpart_mass * v;
162 accretion_mr[id_a] = gpart_mass * r;
163 accretion_ma[id_a] = gpart_mass * a;
168 accretion_l[id_a] = gpart_mass * sycl::cross(r - r_sink, v - v_sink);
176 s_acc_mass += acc_mass;
177 s_acc_pxyz += acc_pxyz;
178 s_acc_mxyz += acc_mxyz;
179 s_acc_maxyz += acc_maxyz;
180 s_acc_lxyz += acc_lxyz;
184 Tscal sum_acc_mass = shamalgs::collective::allreduce_sum(s_acc_mass);
187 if (sum_acc_mass <= 0) {
191 Tvec sum_acc_pxyz = shamalgs::collective::allreduce_sum(s_acc_pxyz);
192 Tvec sum_acc_mxyz = shamalgs::collective::allreduce_sum(s_acc_mxyz);
193 Tvec sum_acc_maxyz = shamalgs::collective::allreduce_sum(s_acc_maxyz);
194 Tvec sum_acc_lxyz = shamalgs::collective::allreduce_sum(s_acc_lxyz);
197 Tscal new_mass = s.mass + sum_acc_mass;
198 Tvec new_pos = (sum_acc_mxyz + s.pos * s.mass) / (s.mass + sum_acc_mass);
199 Tvec new_vel = (sum_acc_pxyz + s.velocity * s.mass) / (s.mass + sum_acc_mass);
200 Tvec new_acc = (sum_acc_maxyz + s.sph_acceleration * s.mass) / (s.mass + sum_acc_mass);
201 Tvec new_ang_mom = s.angular_momentum + sum_acc_lxyz
202 - new_mass * sycl::cross(new_pos - s.pos, new_vel - s.velocity);
206 new_state.mass = new_mass;
207 new_state.pos = new_pos;
208 new_state.velocity = new_vel;
209 new_state.angular_momentum = new_ang_mom;
210 new_state.sph_acceleration = new_acc;
212 had_accretion =
true;
213 log += shambase::format(
214 "\n id {} deltas : mass={} r={} v={} l={}",
216 new_state.mass - s.mass,
217 new_state.pos - s.pos,
218 new_state.velocity - s.velocity,
219 new_state.angular_momentum - s.angular_momentum);
225 u32 Nobj = pdat.get_obj_cnt();
228 = accretion_flag_bufs.get(cur_p.
id_patch).not_accreted;
233 if (id_list_accrete.
get_size() > 0) {
245 logger::info_ln(
"sph::Sink", log);
249template<
class Tvec,
template<
class>
class SPHKernel>
254 if (storage.sinks.is_empty()) {
258 compute_ext_forces();
260 std::vector<Sink> &sink_parts = storage.sinks.get();
262 for (Sink &s : sink_parts) {
263 s.velocity += (dt / 2) * s.sph_acceleration;
266 for (Sink &s : sink_parts) {
267 s.velocity += (dt / 2) * s.ext_acceleration;
270 for (Sink &s : sink_parts) {
271 s.pos += (dt) *s.velocity;
274 for (Sink &s : sink_parts) {
275 s.velocity += (dt / 2) * s.ext_acceleration;
279template<
class Tvec,
template<
class>
class SPHKernel>
284 if (storage.sinks.is_empty()) {
288 std::vector<Sink> &sink_parts = storage.sinks.get();
290 for (Sink &s : sink_parts) {
291 s.velocity += (dt / 2) * s.sph_acceleration;
295template<
class Tvec,
template<
class>
class SPHKernel>
300 Tscal gpart_mass = solver_config.gpart_mass;
302 if (storage.sinks.is_empty()) {
306 std::vector<Sink> &sink_parts = storage.sinks.get();
308 Tscal G = solver_config.get_constant_G();
309 Tscal epsilon_grav = 1e-9;
312 using namespace shamrock::patch;
318 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
321 std::vector<Tvec> result_acc_sinks{};
323 for (Sink &s : sink_parts) {
325 Tvec sph_acc_sink = {};
327 scheduler().for_each_patchdata_nonempty(
334 Tscal sink_mass = s.mass;
335 Tscal sink_racc = s.accretion_radius;
336 Tvec sink_pos = s.pos;
340 auto axyz_ext = buf_axyz_ext.get_write_access(depends_list);
341 auto axyz_sync = buf_sync_axyz.get_write_access(depends_list);
345 [&, G, epsilon_grav, sink_mass, sink_pos, sink_racc](sycl::handler &cgh) {
346 shambase::parallel_for(
347 cgh, pdat.get_obj_cnt(),
"sink-sph forces", [=](
i32 id_a) {
348 Tvec r_a = xyz[id_a];
350 Tvec delta = r_a - sink_pos;
351 Tscal d = sycl::length(delta);
353 Tvec force = G * delta / (d * d * d);
362 axyz_sync[id_a] = force * gpart_mass;
363 axyz_ext[id_a] += -force * sink_mass;
368 buf_axyz_ext.complete_event_state(e);
369 buf_sync_axyz.complete_event_state(e);
375 result_acc_sinks.push_back(sph_acc_sink);
378 std::vector<Tvec> gathered_result_acc_sinks{};
380 result_acc_sinks, gathered_result_acc_sinks, MPI_COMM_WORLD);
383 for (Sink &s : sink_parts) {
385 s.sph_acceleration = {};
388 s.sph_acceleration += gathered_result_acc_sinks[rid * sink_parts.size() + id_s];
395template<
class Tvec,
template<
class>
class SPHKernel>
400 if (storage.sinks.is_empty()) {
404 std::vector<Sink> &sink_parts = storage.sinks.get();
406 for (Sink &s : sink_parts) {
407 s.ext_acceleration = Tvec{};
410 Tscal G = solver_config.get_constant_G();
411 Tscal epsilon_grav_sink = 1e-9;
413 for (Sink &s1 : sink_parts) {
415 for (Sink &s2 : sink_parts) {
416 Tvec rij = s1.pos - s2.pos;
417 Tscal rij_scal = sycl::length(rij);
418 sum -= G * s2.mass * rij / (rij_scal * rij_scal * rij_scal + epsilon_grav_sink);
420 s1.ext_acceleration = sum;
constexpr const char * axyz
3-acceleration field
constexpr const char * vxyz
3-velocity field
constexpr const char * xyz
Position field (3D coordinates)
std::uint32_t u32
32 bit unsigned integer
std::int32_t i32
32 bit signed integer
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
size_t get_size() const
Gets the number of elements in the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
Class to manage a list of SYCL events.
Represents a collection of objects distributed across patches identified by a u64 id.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on vector with size query (size querying variant of vector_allgatherv_ks) //TODO add fault...
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error...
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
i32 world_size()
Gives the size of the MPI communicator.
namespace for math utility
namespace for the main framework
Utilities for safe type narrowing conversions.
A class that references multiple buffers or similar objects.
Patch object that contains generic patch information.
u64 id_patch
Unique key that identifies the patch