Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
SGSFMMPlummer.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
24#include "shamcomm/logs.hpp"
25#include "shammath/AABB.hpp"
36
38
39 template<class Tvec, u32 mm_order>
42
43 using Umorton = u32;
45
46 auto edges = get_edges();
47
48 edges.field_axyz_ext.ensure_sizes(edges.sizes.indexes);
49
50 if (edges.sizes.indexes.get_ids().size() != 1) {
52 "Self gravity direct mode only supports one patch so far, current number "
53 "of patches is : "
54 + std::to_string(edges.sizes.indexes.get_ids().size()));
55 }
56
57 Tscal G = edges.constant_G.data;
58 Tscal gpart_mass = edges.gpart_mass.data;
59
60 Tscal gravitational_softening = epsilon * epsilon;
61
62 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
63 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
64
65 edges.sizes.indexes.for_each([&](u64 id, const u64 &n) {
66 PatchDataField<Tvec> &xyz = edges.field_xyz.get_field(id);
67 PatchDataField<Tvec> &axyz_ext = edges.field_axyz_ext.get_field(id);
68
69 Tvec bmax = xyz.compute_max();
70 Tvec bmin = xyz.compute_min();
71 shammath::AABB<Tvec> aabb(bmin, bmax);
72
73 // build the tree
74 auto bvh = RTree::make_empty(dev_sched);
75 bvh.rebuild_from_positions(xyz.get_buf(), xyz.get_obj_cnt(), aabb, reduction_level);
76
77 // compute moments in leaves
78 auto mass_moments_tree = compute_tree_mass_moments<Tvec, Umorton, mm_order - 1>(
79 bvh, xyz.get_buf(), gpart_mass);
80
81 // DTT
82 auto dtt_result = shamtree::clbvh_dual_tree_traversal(
83 dev_sched, bvh, theta_crit, true, leaf_lowering);
84
85 // M2L step
87 static constexpr u32 grav_moment_terms = GravMoments::num_component;
88
89 using MassMoments = shammath::SymTensorCollection<Tscal, 0, mm_order - 1>;
90 static constexpr u32 mass_moment_terms = MassMoments::num_component;
91
92 auto grav_moments_tree = shamtree::prepare_karras_radix_tree_field_multi_var<Tscal>(
93 bvh.structure,
94 shamtree::new_empty_karras_radix_tree_field_multi_var<Tscal>(grav_moment_terms));
95
96 // we do not need to reset grav moments tree as it will be overwritten in the M2L step
97
98 logger::raw_ln(
99 "SPH", "M2L interact count: ", dtt_result.node_interactions_m2l.get_size());
100 logger::raw_ln(
101 "SPH", "P2P interact count: ", dtt_result.node_interactions_p2p.get_size());
102
103 // M2L kernel
105 q,
107 dtt_result.node_interactions_m2l,
108 dtt_result.ordered_result->offset_m2l,
109 mass_moments_tree.buf_field,
110 bvh.aabbs.buf_aabb_min,
111 bvh.aabbs.buf_aabb_max},
112 sham::MultiRef{grav_moments_tree.buf_field},
113 bvh.structure.get_total_cell_count(),
114 [](u32 cell_id,
115 const u32_2 *m2l_interactions,
116 const u32 *offset_m2l,
117 const Tscal *mass_moments,
118 const Tvec *aabb_min,
119 const Tvec *aabb_max,
120 Tscal *grav_moments) {
121 auto load_mass_moment = [&](u32 cell_id) -> MassMoments {
122 const Tscal *mass_moment_ptr = mass_moments + cell_id * mass_moment_terms;
123 return MassMoments::load(mass_moment_ptr, 0);
124 };
125
126 auto load_aabb = [&](u32 cell_id) -> shammath::AABB<Tvec> {
127 return shammath::AABB<Tvec>{aabb_min[cell_id], aabb_max[cell_id]};
128 };
129
130 GravMoments dM_k = GravMoments::zeros();
131
132 shammath::AABB<Tvec> aabb_A = load_aabb(cell_id);
133 Tvec s_A = aabb_A.get_center();
134
135 for (u32 i = offset_m2l[cell_id]; i < offset_m2l[cell_id + 1]; i++) {
136 u32_2 interaction = m2l_interactions[i];
137 u32 cell_id_a = interaction.x();
138 u32 cell_id_b = interaction.y();
139 SHAM_ASSERT(cell_id_a == cell_id);
140
141 MassMoments Q_n_B = load_mass_moment(cell_id_b);
142
143 Tvec s_B = load_aabb(cell_id_b).get_center();
144
145 Tvec r_fmm = s_B - s_A;
146
147 auto D_n
149 r_fmm);
150
151 dM_k += shamphys::get_dM_mat(D_n, Q_n_B);
152 }
153
154 Tscal *cell_moments_ptr = grav_moments + cell_id * grav_moment_terms;
155 dM_k.store(cell_moments_ptr, 0);
156 });
157
158 // L2L step
159 auto is_moment_complete = shamtree::prepare_karras_radix_tree_field<u8>(
160 bvh.structure, shamtree::new_empty_karras_radix_tree_field<u8>());
161
162 // this one will not be fully overwritten so we need to initialize it to zeros
163 is_moment_complete.buf_field.fill(0_u8);
164
165 // set the root to 1 to start the process
166 is_moment_complete.buf_field.set_val_at_idx(0, 1);
167
168 auto traverser = bvh.structure.get_structure_traverser();
169
170 for (u32 i = 0; i < bvh.structure.tree_depth; i++) {
172 q,
173 sham::MultiRef{traverser, bvh.aabbs.buf_aabb_min, bvh.aabbs.buf_aabb_max},
174 sham::MultiRef{is_moment_complete.buf_field, grav_moments_tree.buf_field},
175 bvh.structure.get_internal_cell_count(),
176 [](u32 cell_id,
177 auto tree_traverser,
178 const Tvec *aabb_min,
179 const Tvec *aabb_max,
180 u8 *is_moment_complete,
181 Tscal *grav_moments) {
182 auto load_grav_moment = [&](u32 cell_id) -> GravMoments {
183 const Tscal *grav_moment_ptr
184 = grav_moments + cell_id * grav_moment_terms;
185 return GravMoments::load(grav_moment_ptr, 0);
186 };
187
188 auto store_grav_moment = [&](u32 cell_id, const GravMoments &grav_moment) {
189 Tscal *grav_moment_ptr = grav_moments + cell_id * grav_moment_terms;
190 grav_moment.store(grav_moment_ptr, 0);
191 };
192
193 u32 left_child = tree_traverser.get_left_child(cell_id);
194 u32 right_child = tree_traverser.get_right_child(cell_id);
195
196 // run only if is_moment_complete is 1
197 // at the end set children to 1
198 u8 should_compute = is_moment_complete[cell_id] == 1
199 && is_moment_complete[left_child] == 0
200 && is_moment_complete[right_child] == 0;
201
202 if (should_compute) {
203
204 u32 left_child = tree_traverser.get_left_child(cell_id);
205 u32 right_child = tree_traverser.get_right_child(cell_id);
206
207 Tvec s_A = shammath::AABB<Tvec>{aabb_min[cell_id], aabb_max[cell_id]}
208 .get_center();
209 Tvec s_left
210 = shammath::AABB<Tvec>{aabb_min[left_child], aabb_max[left_child]}
211 .get_center();
212 Tvec s_right
213 = shammath::AABB<Tvec>{aabb_min[right_child], aabb_max[right_child]}
214 .get_center();
215
216 // perform L2L
217 GravMoments my_moment = load_grav_moment(cell_id);
218
219 GravMoments left_moment = load_grav_moment(left_child);
220 GravMoments right_moment = load_grav_moment(right_child);
221
222 left_moment += shamphys::offset_dM_mat(my_moment, s_A, s_left);
223 right_moment += shamphys::offset_dM_mat(my_moment, s_A, s_right);
224
225 store_grav_moment(left_child, left_moment);
226 store_grav_moment(right_child, right_moment);
227
228 is_moment_complete[left_child] = 1;
229 is_moment_complete[right_child] = 1;
230 }
231 });
232 }
233
234 // L2P
235 auto cell_it = bvh.reduced_morton_set.get_leaf_cell_iterator();
237 q,
239 xyz.get_buf(),
240 cell_it,
241 bvh.aabbs.buf_aabb_min,
242 bvh.aabbs.buf_aabb_max,
243 grav_moments_tree.buf_field},
244 sham::MultiRef{axyz_ext.get_buf()},
245 bvh.structure.get_leaf_count(),
246 [leaf_offset = bvh.structure.get_internal_cell_count(),
247 G](u32 ileaf,
248 const Tvec *xyz,
249 auto cell_iter,
250 const Tvec *aabb_min,
251 const Tvec *aabb_max,
252 const Tscal *grav_moments,
253 Tvec *axyz_ext) {
254 auto load_grav_moment = [&](u32 cell_id) -> GravMoments {
255 const Tscal *grav_moment_ptr = grav_moments + cell_id * grav_moment_terms;
256 return GravMoments::load(grav_moment_ptr, 0);
257 };
258
259 u32 cell_id = ileaf + leaf_offset;
260 GravMoments dM_k = load_grav_moment(cell_id);
261
262 Tvec s_A
263 = shammath::AABB<Tvec>{aabb_min[cell_id], aabb_max[cell_id]}.get_center();
264
265 cell_iter.for_each_in_leaf_cell(ileaf, [&](u32 i) {
266 Tvec a_i = xyz[i] - s_A;
267
268 auto a_k
270
271 axyz_ext[i] += -G
272 * shamphys::contract_grav_moment_to_force<Tscal, mm_order>(
273 a_k, dM_k);
274 });
275 });
276
277 if (false) {
278 // tmp checks
279 u32 leaf_offset = bvh.structure.get_internal_cell_count();
280 auto node_it = bvh.reduced_morton_set.get_cell_iterator_host(
281 bvh.structure.buf_endrange, leaf_offset);
282
283 auto node_iter = node_it.get_read_access();
284
285 auto offset_p2p = dtt_result.ordered_result->offset_p2p.copy_to_stdvec();
286 auto p2p_interactions = dtt_result.node_interactions_p2p.copy_to_stdvec();
287
288 std::vector<std::pair<u32, u32>> part_calls = {};
289 std::set<std::pair<u32, u32>> part_calls_set = {};
290 std::vector<std::set<u32>> a_acc_by_tid = {};
291 for (u32 i = 0; i < xyz.get_obj_cnt(); i++) {
292 a_acc_by_tid.push_back(std::set<u32>());
293 }
294
295 // cpu variant of the kernel
296 for (u32 icell = 0; icell < bvh.structure.get_total_cell_count(); icell++) {
297
298 shamcomm::logs::raw_ln(
299 "cell id: ",
300 icell,
301 "runs from offset: ",
302 offset_p2p[icell],
303 "to offset: ",
304 offset_p2p[icell + 1]);
305
306 std::vector<u32> local_particles = {};
307 node_iter.for_each_in_cell(icell, [&](u32 i) {
308 local_particles.push_back(i);
309 });
310
311 shamcomm::logs::raw_ln(
312 "local particles: cell id: ", icell, " data: ", local_particles);
313
314 for (u32 j = offset_p2p[icell]; j < offset_p2p[icell + 1]; j++) {
315 u32_2 interaction = p2p_interactions[j];
316 u32 cell_id_a = interaction.x();
317 u32 cell_id_b = interaction.y();
318
319 if (icell != cell_id_a) {
321 "Cell id mismatch: " + std::to_string(icell)
322 + " != " + std::to_string(cell_id_a));
323 }
324
325 node_iter.for_each_in_cell(cell_id_a, [&](u32 i) {
326 node_iter.for_each_in_cell(cell_id_b, [&](u32 j) {
327 std::pair<u32, u32> call = {i, j};
328 part_calls.push_back(call);
329 part_calls_set.insert(call);
330 });
331
332 if (a_acc_by_tid[i].count(icell) == 0) {
333 a_acc_by_tid[i].insert(icell);
334 }
335
336 if (a_acc_by_tid[i].size() > 1) {
338 "Potential race condition detected in part_calls for particle "
339 + std::to_string(i) + " and cell " + std::to_string(icell)
340 + " current set: " + shambase::format("{}", a_acc_by_tid[i]));
341 }
342 });
343 }
344 }
345
346 shamlog_info_ln(
347 "FMM",
348 "Part calls size: ",
349 part_calls.size(),
350 "expected: ",
351 xyz.get_obj_cnt() * xyz.get_obj_cnt());
352 if (part_calls.size() != xyz.get_obj_cnt() * xyz.get_obj_cnt()) {
353 // throw shambase::make_except_with_loc<std::runtime_error>(
354 // "Part calls size mismatch");
355 }
356
357 for (size_t outer = 0; outer < part_calls.size(); ++outer) {
358 const auto call = part_calls[outer];
359
360 u32 count = part_calls_set.count(call);
361 if (count != 1) {
362 shamlog_error_ln("FMM", "Duplicate particle call detected in part_calls.");
364 "Duplicate particle call detected in part_calls for particle "
365 + std::to_string(call.first) + " and " + std::to_string(call.second)
366 + " with count " + std::to_string(count));
367 }
368
369 u32 count_reverse = part_calls_set.count({call.second, call.first});
370 if (count_reverse != 1) {
371 shamlog_error_ln(
372 "FMM", "Duplicate reverse particle call detected in part_calls.");
374 "Duplicate reverse particle call detected in part_calls for particle "
375 + std::to_string(call.second) + " and " + std::to_string(call.first)
376 + " with count " + std::to_string(count_reverse));
377 }
378 }
379
380 shamcomm::logs::raw_ln(a_acc_by_tid);
381
382 throw "";
383 }
384
385 // P2P
386 // enum mode { atomic, invert_list } p2p_mode = (leaf_lowering ? atomic : invert_list);
387 enum mode { atomic, invert_list } p2p_mode = invert_list;
388
389 if (p2p_mode == atomic) {
390 u32 leaf_offset = bvh.structure.get_internal_cell_count();
391 auto node_it = bvh.reduced_morton_set.get_cell_iterator(
392 bvh.structure.buf_endrange, leaf_offset);
394 q,
396 xyz.get_buf(),
397 node_it,
398 dtt_result.node_interactions_p2p,
399 dtt_result.ordered_result->offset_p2p},
400 sham::MultiRef{axyz_ext.get_buf()},
401 bvh.structure.get_total_cell_count(),
402 [leaf_offset, gpart_mass, G, gravitational_softening](
403 u32 icell,
404 const Tvec *xyz,
405 auto node_iter,
406 const u32_2 *p2p_interactions,
407 const u32 *offset_p2p,
408 Tvec *axyz_ext) {
409 auto start_id = offset_p2p[icell];
410 auto end_id = offset_p2p[icell + 1];
411 node_iter.for_each_in_cell(icell, [&](u32 i) {
412 Tvec f_i = {};
413
414 for (u32 j = offset_p2p[icell]; j < offset_p2p[icell + 1]; j++) {
415 u32_2 interaction = p2p_interactions[j];
416 u32 cell_id_a = interaction.x();
417 u32 cell_id_b = interaction.y();
418
419 SHAM_ASSERT(icell == cell_id_a);
420
421 node_iter.for_each_in_cell(cell_id_b, [&](u32 j) {
422 Tvec R = xyz[j] - xyz[i];
423 const Tscal r_inv = sycl::rsqrt(
424 R.x() * R.x() + R.y() * R.y() + R.z() * R.z()
425 + gravitational_softening);
426 f_i += G * gpart_mass * r_inv * r_inv * r_inv * R;
427 });
428 }
429
430 using aref = sycl::atomic_ref<
431 Tscal,
432 sycl::memory_order::relaxed,
433 sycl::memory_scope::device,
434 sycl::access::address_space::global_space>;
435
436 aref atomic_ref_x(axyz_ext[i].x());
437 atomic_ref_x += f_i.x();
438 aref atomic_ref_y(axyz_ext[i].y());
439 atomic_ref_y += f_i.y();
440 aref atomic_ref_z(axyz_ext[i].z());
441 atomic_ref_z += f_i.z();
442 });
443 });
444
445 } else if (p2p_mode == invert_list) {
446
447 u32 npart = xyz.get_obj_cnt();
448
449 sham::DeviceBuffer<u32> cells_per_particle(npart + 1, dev_sched);
450 cells_per_particle.fill(0);
451
452 u32 leaf_offset = bvh.structure.get_internal_cell_count();
453 auto node_it = bvh.reduced_morton_set.get_cell_iterator(
454 bvh.structure.buf_endrange, leaf_offset);
456 q,
457 sham::MultiRef{node_it},
458 sham::MultiRef{cells_per_particle},
459 bvh.structure.get_total_cell_count(),
460 [leaf_offset](u32 icell, auto node_iter, u32 *cells_per_particle) {
461 node_iter.for_each_in_cell(icell, [&](u32 i) {
462 using aref = sycl::atomic_ref<
463 u32,
464 sycl::memory_order::relaxed,
465 sycl::memory_scope::device,
466 sycl::access::address_space::global_space>;
467 aref atomic_ref(cells_per_particle[i]);
468 atomic_ref += 1_u32;
469 });
470 });
471
472 // logger::raw_ln("cells_per_particle: ", cells_per_particle.copy_to_stdvec());
473
474 shamalgs::primitives::scan_exclusive_sum_in_place(cells_per_particle, npart + 1);
475
476 auto &cells_per_part_offset = cells_per_particle;
477 u32 total_cells_per_particle = cells_per_part_offset.get_val_at_idx(npart);
478
479 // logger::raw_ln("cells_per_part_offset: ",
480 // cells_per_part_offset.copy_to_stdvec());
481
482 sham::DeviceBuffer<u32> cells_containing_particle(
483 total_cells_per_particle, dev_sched);
484 {
485 sham::DeviceBuffer<u32> tmp_offset(npart + 1, dev_sched);
486 tmp_offset.fill(0);
487
489 q,
490 sham::MultiRef{node_it, cells_per_part_offset},
491 sham::MultiRef{tmp_offset, cells_containing_particle},
492 bvh.structure.get_total_cell_count(),
493 [leaf_offset](
494 u32 icell,
495 auto node_iter,
496 const u32 *cells_per_part_offset,
497 u32 *tmp_offset,
498 u32 *cells_containing_particle) {
499 node_iter.for_each_in_cell(icell, [&](u32 i) {
500 using aref = sycl::atomic_ref<
501 u32,
502 sycl::memory_order::relaxed,
503 sycl::memory_scope::device,
504 sycl::access::address_space::global_space>;
505 aref atomic_ref(tmp_offset[i]);
506
507 u32 write_id
508 = atomic_ref.fetch_add(1_u32) + cells_per_part_offset[i];
509
510 cells_containing_particle[write_id] = icell;
511 });
512 });
513 }
514
515 // logger::raw_ln(
516 // "cells_containing_particle: ", cells_containing_particle.copy_to_stdvec());
517
518 shamalgs::primitives::segmented_sort_in_place(
519 cells_containing_particle, cells_per_part_offset);
520
521 // logger::raw_ln(
522 // "cells_containing_particle: ", cells_containing_particle.copy_to_stdvec());
523
525 q,
527 xyz.get_buf(),
528 node_it,
529 cells_per_part_offset,
530 cells_containing_particle,
531 dtt_result.node_interactions_p2p,
532 dtt_result.ordered_result->offset_p2p},
533 sham::MultiRef{axyz_ext.get_buf()},
534 npart,
535 [leaf_offset, gpart_mass, G, gravitational_softening](
536 u32 id_a,
537 const Tvec *xyz,
538 auto node_iter,
539 const u32 *cells_per_part_offset,
540 const u32 *cells_containing_particle,
541 const u32_2 *p2p_interactions,
542 const u32 *offset_p2p,
543 Tvec *axyz_ext) {
544 Tvec f_i = {};
545
546 u32 start_icell_idx = cells_per_part_offset[id_a];
547 u32 end_icell_idx = cells_per_part_offset[id_a + 1];
548
549 for (u32 icell_idx = start_icell_idx; icell_idx < end_icell_idx;
550 icell_idx++) {
551 u32 icell = cells_containing_particle[icell_idx];
552
553 for (u32 j = offset_p2p[icell]; j < offset_p2p[icell + 1]; j++) {
554 u32_2 interaction = p2p_interactions[j];
555 u32 cell_id_a = interaction.x();
556 u32 cell_id_b = interaction.y();
557
558 SHAM_ASSERT(icell == cell_id_a);
559
560 node_iter.for_each_in_cell(cell_id_b, [&](u32 j) {
561 Tvec R = xyz[j] - xyz[id_a];
562 const Tscal r_inv = sycl::rsqrt(
563 R.x() * R.x() + R.y() * R.y() + R.z() * R.z()
564 + gravitational_softening);
565
566 f_i += G * gpart_mass * r_inv * r_inv * r_inv * R;
567 });
568 }
569 }
570
571 axyz_ext[id_a] += f_i;
572 });
573 } else {
574 throw shambase::make_except_with_loc<std::runtime_error>("Unsupported p2p mode");
575 }
576 });
577 }
578} // namespace shammodels::sph::modules
579
Dual tree traversal algorithm for Compressed Leaf Bounding Volume Hierarchies.
std::uint8_t u8
8 bit unsigned integer
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
Definition assert.hpp:67
A buffer allocated in USM (Unified Shared Memory)
void fill(T value, std::array< size_t, 2 > idx_range)
Fill a subpart of the buffer with a given value.
T get_val_at_idx(size_t idx) const
Get the value at a given index in the buffer.
A SYCL queue associated with a device and a context.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
void _impl_evaluate_internal() override
Evaluate the node.
Utility to get the derivatives of the Green function for gravity in Cartesian coordinates.
A Compressed Leaf Bounding Volume Hierarchy (CLBVH) for neighborhood queries.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
void scan_exclusive_sum_in_place(sham::DeviceBuffer< T > &buf1, u32 len)
Compute exclusive prefix sum in-place on a device buffer.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
Namespace for the SPH model modules.
In-place exclusive scan (prefix sum) algorithm for device buffers.
#define __shamrock_stack_entry()
Macro to create a stack entry.
A class that references multiple buffers or similar objects.
Axis-Aligned bounding box.
Definition AABB.hpp:99
T get_center() const noexcept
Returns the center of the AABB.
Definition AABB.hpp:174