46 auto edges = get_edges();
48 edges.field_axyz_ext.ensure_sizes(edges.sizes.indexes);
50 if (edges.sizes.indexes.get_ids().size() != 1) {
52 "Self gravity direct mode only supports one patch so far, current number "
54 + std::to_string(edges.sizes.indexes.get_ids().size()));
57 Tscal G = edges.constant_G.data;
58 Tscal gpart_mass = edges.gpart_mass.data;
60 Tscal gravitational_softening = epsilon * epsilon;
62 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
65 edges.sizes.indexes.for_each([&](
u64 id,
const u64 &n) {
69 Tvec bmax = xyz.compute_max();
70 Tvec bmin = xyz.compute_min();
74 auto bvh = RTree::make_empty(dev_sched);
75 bvh.rebuild_from_positions(xyz.get_buf(), xyz.get_obj_cnt(), aabb, reduction_level);
78 auto mass_moments_tree = compute_tree_mass_moments<Tvec, Umorton, mm_order - 1>(
79 bvh, xyz.get_buf(), gpart_mass);
82 auto dtt_result = shamtree::clbvh_dual_tree_traversal(
83 dev_sched, bvh, theta_crit,
true, leaf_lowering);
87 static constexpr u32 grav_moment_terms = GravMoments::num_component;
90 static constexpr u32 mass_moment_terms = MassMoments::num_component;
92 auto grav_moments_tree = shamtree::prepare_karras_radix_tree_field_multi_var<Tscal>(
94 shamtree::new_empty_karras_radix_tree_field_multi_var<Tscal>(grav_moment_terms));
99 "SPH",
"M2L interact count: ", dtt_result.node_interactions_m2l.get_size());
101 "SPH",
"P2P interact count: ", dtt_result.node_interactions_p2p.get_size());
107 dtt_result.node_interactions_m2l,
108 dtt_result.ordered_result->offset_m2l,
109 mass_moments_tree.buf_field,
110 bvh.aabbs.buf_aabb_min,
111 bvh.aabbs.buf_aabb_max},
113 bvh.structure.get_total_cell_count(),
115 const u32_2 *m2l_interactions,
116 const u32 *offset_m2l,
117 const Tscal *mass_moments,
118 const Tvec *aabb_min,
119 const Tvec *aabb_max,
120 Tscal *grav_moments) {
121 auto load_mass_moment = [&](
u32 cell_id) -> MassMoments {
122 const Tscal *mass_moment_ptr = mass_moments + cell_id * mass_moment_terms;
123 return MassMoments::load(mass_moment_ptr, 0);
130 GravMoments dM_k = GravMoments::zeros();
135 for (
u32 i = offset_m2l[cell_id]; i < offset_m2l[cell_id + 1]; i++) {
136 u32_2 interaction = m2l_interactions[i];
137 u32 cell_id_a = interaction.x();
138 u32 cell_id_b = interaction.y();
141 MassMoments Q_n_B = load_mass_moment(cell_id_b);
143 Tvec s_B = load_aabb(cell_id_b).get_center();
145 Tvec r_fmm = s_B - s_A;
151 dM_k += shamphys::get_dM_mat(D_n, Q_n_B);
154 Tscal *cell_moments_ptr = grav_moments + cell_id * grav_moment_terms;
155 dM_k.store(cell_moments_ptr, 0);
159 auto is_moment_complete = shamtree::prepare_karras_radix_tree_field<u8>(
160 bvh.structure, shamtree::new_empty_karras_radix_tree_field<u8>());
163 is_moment_complete.buf_field.fill(0_u8);
166 is_moment_complete.buf_field.set_val_at_idx(0, 1);
168 auto traverser = bvh.structure.get_structure_traverser();
170 for (
u32 i = 0; i < bvh.structure.tree_depth; i++) {
173 sham::MultiRef{traverser, bvh.aabbs.buf_aabb_min, bvh.aabbs.buf_aabb_max},
174 sham::MultiRef{is_moment_complete.buf_field, grav_moments_tree.buf_field},
175 bvh.structure.get_internal_cell_count(),
178 const Tvec *aabb_min,
179 const Tvec *aabb_max,
180 u8 *is_moment_complete,
181 Tscal *grav_moments) {
182 auto load_grav_moment = [&](
u32 cell_id) -> GravMoments {
183 const Tscal *grav_moment_ptr
184 = grav_moments + cell_id * grav_moment_terms;
185 return GravMoments::load(grav_moment_ptr, 0);
188 auto store_grav_moment = [&](
u32 cell_id,
const GravMoments &grav_moment) {
189 Tscal *grav_moment_ptr = grav_moments + cell_id * grav_moment_terms;
190 grav_moment.store(grav_moment_ptr, 0);
193 u32 left_child = tree_traverser.get_left_child(cell_id);
194 u32 right_child = tree_traverser.get_right_child(cell_id);
198 u8 should_compute = is_moment_complete[cell_id] == 1
199 && is_moment_complete[left_child] == 0
200 && is_moment_complete[right_child] == 0;
202 if (should_compute) {
204 u32 left_child = tree_traverser.get_left_child(cell_id);
205 u32 right_child = tree_traverser.get_right_child(cell_id);
217 GravMoments my_moment = load_grav_moment(cell_id);
219 GravMoments left_moment = load_grav_moment(left_child);
220 GravMoments right_moment = load_grav_moment(right_child);
222 left_moment += shamphys::offset_dM_mat(my_moment, s_A, s_left);
223 right_moment += shamphys::offset_dM_mat(my_moment, s_A, s_right);
225 store_grav_moment(left_child, left_moment);
226 store_grav_moment(right_child, right_moment);
228 is_moment_complete[left_child] = 1;
229 is_moment_complete[right_child] = 1;
235 auto cell_it = bvh.reduced_morton_set.get_leaf_cell_iterator();
241 bvh.aabbs.buf_aabb_min,
242 bvh.aabbs.buf_aabb_max,
243 grav_moments_tree.buf_field},
245 bvh.structure.get_leaf_count(),
246 [leaf_offset = bvh.structure.get_internal_cell_count(),
250 const Tvec *aabb_min,
251 const Tvec *aabb_max,
252 const Tscal *grav_moments,
254 auto load_grav_moment = [&](u32 cell_id) -> GravMoments {
255 const Tscal *grav_moment_ptr = grav_moments + cell_id * grav_moment_terms;
256 return GravMoments::load(grav_moment_ptr, 0);
259 u32 cell_id = ileaf + leaf_offset;
260 GravMoments dM_k = load_grav_moment(cell_id);
265 cell_iter.for_each_in_leaf_cell(ileaf, [&](
u32 i) {
266 Tvec a_i = xyz[i] - s_A;
272 * shamphys::contract_grav_moment_to_force<Tscal, mm_order>(
279 u32 leaf_offset = bvh.structure.get_internal_cell_count();
280 auto node_it = bvh.reduced_morton_set.get_cell_iterator_host(
281 bvh.structure.buf_endrange, leaf_offset);
283 auto node_iter = node_it.get_read_access();
285 auto offset_p2p = dtt_result.ordered_result->offset_p2p.copy_to_stdvec();
286 auto p2p_interactions = dtt_result.node_interactions_p2p.copy_to_stdvec();
288 std::vector<std::pair<u32, u32>> part_calls = {};
289 std::set<std::pair<u32, u32>> part_calls_set = {};
290 std::vector<std::set<u32>> a_acc_by_tid = {};
291 for (
u32 i = 0; i < xyz.get_obj_cnt(); i++) {
292 a_acc_by_tid.push_back(std::set<u32>());
296 for (
u32 icell = 0; icell < bvh.structure.get_total_cell_count(); icell++) {
298 shamcomm::logs::raw_ln(
301 "runs from offset: ",
304 offset_p2p[icell + 1]);
306 std::vector<u32> local_particles = {};
307 node_iter.for_each_in_cell(icell, [&](
u32 i) {
308 local_particles.push_back(i);
311 shamcomm::logs::raw_ln(
312 "local particles: cell id: ", icell,
" data: ", local_particles);
314 for (
u32 j = offset_p2p[icell]; j < offset_p2p[icell + 1]; j++) {
315 u32_2 interaction = p2p_interactions[j];
316 u32 cell_id_a = interaction.x();
317 u32 cell_id_b = interaction.y();
319 if (icell != cell_id_a) {
321 "Cell id mismatch: " + std::to_string(icell)
322 +
" != " + std::to_string(cell_id_a));
325 node_iter.for_each_in_cell(cell_id_a, [&](
u32 i) {
326 node_iter.for_each_in_cell(cell_id_b, [&](
u32 j) {
327 std::pair<u32, u32> call = {i, j};
328 part_calls.push_back(call);
329 part_calls_set.insert(call);
332 if (a_acc_by_tid[i].count(icell) == 0) {
333 a_acc_by_tid[i].insert(icell);
336 if (a_acc_by_tid[i].size() > 1) {
338 "Potential race condition detected in part_calls for particle "
339 + std::to_string(i) +
" and cell " + std::to_string(icell)
340 +
" current set: " + shambase::format(
"{}", a_acc_by_tid[i]));
351 xyz.get_obj_cnt() * xyz.get_obj_cnt());
352 if (part_calls.size() != xyz.get_obj_cnt() * xyz.get_obj_cnt()) {
357 for (
size_t outer = 0; outer < part_calls.size(); ++outer) {
358 const auto call = part_calls[outer];
360 u32 count = part_calls_set.count(call);
362 shamlog_error_ln(
"FMM",
"Duplicate particle call detected in part_calls.");
364 "Duplicate particle call detected in part_calls for particle "
365 + std::to_string(call.first) +
" and " + std::to_string(call.second)
366 +
" with count " + std::to_string(count));
369 u32 count_reverse = part_calls_set.count({call.second, call.first});
370 if (count_reverse != 1) {
372 "FMM",
"Duplicate reverse particle call detected in part_calls.");
374 "Duplicate reverse particle call detected in part_calls for particle "
375 + std::to_string(call.second) +
" and " + std::to_string(call.first)
376 +
" with count " + std::to_string(count_reverse));
380 shamcomm::logs::raw_ln(a_acc_by_tid);
387 enum mode { atomic, invert_list } p2p_mode = invert_list;
389 if (p2p_mode == atomic) {
390 u32 leaf_offset = bvh.structure.get_internal_cell_count();
391 auto node_it = bvh.reduced_morton_set.get_cell_iterator(
392 bvh.structure.buf_endrange, leaf_offset);
398 dtt_result.node_interactions_p2p,
399 dtt_result.ordered_result->offset_p2p},
401 bvh.structure.get_total_cell_count(),
402 [leaf_offset, gpart_mass, G, gravitational_softening](
406 const u32_2 *p2p_interactions,
407 const u32 *offset_p2p,
409 auto start_id = offset_p2p[icell];
410 auto end_id = offset_p2p[icell + 1];
411 node_iter.for_each_in_cell(icell, [&](
u32 i) {
414 for (
u32 j = offset_p2p[icell]; j < offset_p2p[icell + 1]; j++) {
415 u32_2 interaction = p2p_interactions[j];
416 u32 cell_id_a = interaction.x();
417 u32 cell_id_b = interaction.y();
421 node_iter.for_each_in_cell(cell_id_b, [&](
u32 j) {
422 Tvec R = xyz[j] - xyz[i];
423 const Tscal r_inv = sycl::rsqrt(
424 R.x() * R.x() + R.y() * R.y() + R.z() * R.z()
425 + gravitational_softening);
426 f_i += G * gpart_mass * r_inv * r_inv * r_inv * R;
430 using aref = sycl::atomic_ref<
432 sycl::memory_order::relaxed,
433 sycl::memory_scope::device,
434 sycl::access::address_space::global_space>;
436 aref atomic_ref_x(axyz_ext[i].x());
437 atomic_ref_x += f_i.x();
438 aref atomic_ref_y(axyz_ext[i].y());
439 atomic_ref_y += f_i.y();
440 aref atomic_ref_z(axyz_ext[i].z());
441 atomic_ref_z += f_i.z();
445 }
else if (p2p_mode == invert_list) {
447 u32 npart = xyz.get_obj_cnt();
450 cells_per_particle.
fill(0);
452 u32 leaf_offset = bvh.structure.get_internal_cell_count();
453 auto node_it = bvh.reduced_morton_set.get_cell_iterator(
454 bvh.structure.buf_endrange, leaf_offset);
459 bvh.structure.get_total_cell_count(),
460 [leaf_offset](
u32 icell,
auto node_iter,
u32 *cells_per_particle) {
461 node_iter.for_each_in_cell(icell, [&](
u32 i) {
462 using aref = sycl::atomic_ref<
464 sycl::memory_order::relaxed,
465 sycl::memory_scope::device,
466 sycl::access::address_space::global_space>;
467 aref atomic_ref(cells_per_particle[i]);
476 auto &cells_per_part_offset = cells_per_particle;
483 total_cells_per_particle, dev_sched);
492 bvh.structure.get_total_cell_count(),
496 const u32 *cells_per_part_offset,
498 u32 *cells_containing_particle) {
499 node_iter.for_each_in_cell(icell, [&](
u32 i) {
500 using aref = sycl::atomic_ref<
502 sycl::memory_order::relaxed,
503 sycl::memory_scope::device,
504 sycl::access::address_space::global_space>;
505 aref atomic_ref(tmp_offset[i]);
508 = atomic_ref.fetch_add(1_u32) + cells_per_part_offset[i];
510 cells_containing_particle[write_id] = icell;
518 shamalgs::primitives::segmented_sort_in_place(
519 cells_containing_particle, cells_per_part_offset);
529 cells_per_part_offset,
530 cells_containing_particle,
531 dtt_result.node_interactions_p2p,
532 dtt_result.ordered_result->offset_p2p},
535 [leaf_offset, gpart_mass, G, gravitational_softening](
539 const u32 *cells_per_part_offset,
540 const u32 *cells_containing_particle,
541 const u32_2 *p2p_interactions,
542 const u32 *offset_p2p,
546 u32 start_icell_idx = cells_per_part_offset[id_a];
547 u32 end_icell_idx = cells_per_part_offset[id_a + 1];
549 for (
u32 icell_idx = start_icell_idx; icell_idx < end_icell_idx;
551 u32 icell = cells_containing_particle[icell_idx];
553 for (
u32 j = offset_p2p[icell]; j < offset_p2p[icell + 1]; j++) {
554 u32_2 interaction = p2p_interactions[j];
555 u32 cell_id_a = interaction.x();
556 u32 cell_id_b = interaction.y();
560 node_iter.for_each_in_cell(cell_id_b, [&](
u32 j) {
561 Tvec R = xyz[j] - xyz[id_a];
562 const Tscal r_inv = sycl::rsqrt(
563 R.x() * R.x() + R.y() * R.y() + R.z() * R.z()
564 + gravitational_softening);
566 f_i += G * gpart_mass * r_inv * r_inv * r_inv * R;
571 axyz_ext[id_a] += f_i;