Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
dtt_parallel_select.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
20#include "shambase/memory.hpp"
23#include "shambackends/vec.hpp"
24#include "shammath/AABB.hpp"
28
29namespace shamtree::details {
30
31 template<class Tmorton, class Tvec, u32 dim>
33
34 using Tscal = shambase::VecComponent<Tvec>;
35
/// @brief Member-level shorthand for the namespace-scope shamtree::details::mac.
///
/// Forwards both bounding boxes and theta_crit unchanged and returns that function's result.
/// In dtt_internal a `true` result makes the pair (a, b) an m2l interaction, while `false`
/// makes the traversal open the pair (descend further or record a p2p interaction).
///
/// @param a first bounding box of the tested pair.
/// @param b second bounding box of the tested pair.
/// @param theta_crit threshold handed to the underlying criterion.
/// @return whatever shamtree::details::mac returns for (a, b, theta_crit).
///
/// NOTE(review): the criterion itself is defined outside this listing — confirm its exact
/// meaning there.
36 inline static bool mac(shammath::AABB<Tvec> a, shammath::AABB<Tvec> b, Tscal theta_crit) {
37 return shamtree::details::mac(a, b, theta_crit);
38 }
39
/// @brief Builds the m2l and p2p interaction lists of the tree `bvh` traversed against itself.
///
/// The work is done in two device passes that run the same traversal for every cell index i
/// in [0, total_cell_count):
///  1. count how many m2l and p2p pairs have i as their first index;
///  2. exclusive-scan the counts to get per-cell write offsets and the totals, allocate the
///     result buffers, then redo the traversal and write the (a, b) index pairs.
///
/// @tparam allow_leaf_lowering selects at compile time which descent rule is applied to a pair
///         that fails the mac test (see the two `if constexpr` branches).
/// @param dev_sched device scheduler; provides the queue and is used to allocate the buffers.
/// @param theta_crit value forwarded to mac() for every tested pair.
/// @param ordered_result when true, the two scan buffers are moved into `ret.ordered_result`.
/// @return a DTTResult built from the m2l and p2p index buffers.
///
/// NOTE(review): the `bvh` parameter, the `ObjectIterator` alias, the kernel launch calls and
/// the `aabb_i`, `stack`, `scan_m2l` and `scan_p2p` declarations sit on lines that are absent
/// from this listing; only their uses are visible here.
/// NOTE(review): m2l / p2p presumably stand for multipole-to-local and particle-to-particle —
/// confirm against the DTTResult documentation.
40 template<bool allow_leaf_lowering>
41 inline static shamtree::DTTResult dtt_internal(
42 sham::DeviceScheduler_ptr dev_sched,
44 shambase::VecComponent<Tvec> theta_crit,
45 bool ordered_result) {
46 StackEntry stack_loc{};
47
48 auto q = shambase::get_check_ref(dev_sched).get_queue();
49
51 ObjectIterator obj_it = bvh.get_object_iterator();
52
53 using ObjItAcc = typename ObjectIterator::acc;
54
55 u32 total_cell_count = bvh.structure.get_total_cell_count();
56
// Both count buffers get one extra trailing slot. It is zeroed here and never written by the
// counting kernel (which only covers total_cell_count indices), so after the exclusive scan
// the entry at index total_cell_count holds the grand total.
57 sham::DeviceBuffer<u32> count_m2l(total_cell_count + 1, dev_sched);
58 sham::DeviceBuffer<u32> count_p2p(total_cell_count + 1, dev_sched);
59 count_m2l.set_val_at_idx(total_cell_count, 0);
60 count_p2p.set_val_at_idx(total_cell_count, 0);
61
62 // count the number of interactions for each cell
63
// First pass (the line opening this call is not visible in the listing): obj_it is passed
// as input, the two count buffers as output, and the functor receives the cell index i.
65 q,
66 sham::MultiRef{obj_it},
67 sham::MultiRef{count_m2l, count_p2p},
68 total_cell_count,
69 [theta_crit](
70 u32 i,
71 ObjItAcc obj_it,
72 u32 *__restrict__ count_m2l,
73 u32 *__restrict__ count_p2p) {
// Bounding box of cell i (initializer of `aabb_i`, whose declaration line is not visible).
75 = {obj_it.tree_traverser.aabb_min[i], obj_it.tree_traverser.aabb_max[i]};
76
// True when the bounding box of node_id contains the bounding box of cell i. Used to follow
// only the branches on the `a` side that can lead to cell i.
77 auto is_kdnode_within_node = [&](u32 node_id) -> bool {
78 shammath::AABB<Tvec> aabb_node
79 = {obj_it.tree_traverser.aabb_min[node_id],
80 obj_it.tree_traverser.aabb_max[node_id]};
81
82 return aabb_node.contains(aabb_i);
83 };
84
// `stack` (holding (a, b) node-index pairs) is declared on a line not visible here.
86
87 auto &ttrav = obj_it.tree_traverser.tree_traverser;
88
// Local counters; written to the buffers once at the end of the traversal.
89 u32 count_m2l_i = 0;
90 u32 count_p2p_i = 0;
91
92 // Am I a leaf before we start going down the tree ?
93 if (i == 0 && ttrav.is_id_leaf(0)) {
94 count_p2p_i++;
95 } else {
96 // push root-root interact on stack
97 // We make the assumption that the root is not a leaf
98 stack.push({0, 0});
99 }
100
101 while (stack.is_not_empty()) {
102 u32_2 t = stack.pop_ret();
103 u32 a = t.x();
104 u32 b = t.y();
105
// Only pairs whose first index is this work-item's cell are counted.
106 bool is_a_i_same = a == i;
107
108 shammath::AABB<Tvec> aabb_a = {
109 obj_it.tree_traverser.aabb_min[a], obj_it.tree_traverser.aabb_max[a]};
110 shammath::AABB<Tvec> aabb_b = {
111 obj_it.tree_traverser.aabb_min[b], obj_it.tree_traverser.aabb_max[b]};
112
// crit == true means mac() rejected the pair, so it has to be opened.
113 bool crit = mac(aabb_a, aabb_b, theta_crit) == false;
114
115 if (crit) {
116
117 if constexpr (allow_leaf_lowering) {
118
119 bool is_a_leaf = ttrav.is_id_leaf(a);
120 bool is_b_leaf = ttrav.is_id_leaf(b);
121
// Both sides are leaves: nothing left to open, the pair is a p2p interaction.
122 if (is_a_leaf && is_b_leaf) {
123 if (is_a_i_same) {
124 count_p2p_i++;
125 }
126 continue;
127 }
128
// A leaf side is kept as is (both child slots alias the node itself and the second slot is
// disabled through run_*_2 below); a non-leaf side is split into its two children.
129 u32 child_a_1 = (is_a_leaf) ? a : ttrav.get_left_child(a);
130 u32 child_a_2 = (is_a_leaf) ? a : ttrav.get_right_child(a);
131 u32 child_b_1 = (is_b_leaf) ? b : ttrav.get_left_child(b);
132 u32 child_b_2 = (is_b_leaf) ? b : ttrav.get_right_child(b);
133
134 bool run_a_1 = true;
135 bool run_a_2 = !is_a_leaf;
136 bool run_b_1 = true;
137 bool run_b_2 = !is_b_leaf;
138
139 // now since we can re-enqueue the same node we need to escape only
140 // if the child is enqueued so we replace the is_a_i_same condition
141 // from the case without lowering by is_a_i_same && (child_a_1 != a)
142 if (is_a_i_same && (child_a_1 != a)) {
143 continue;
144 }
145
146 bool is_node_i_in_left_a = is_kdnode_within_node(child_a_1);
147 bool is_node_i_in_right_a = is_kdnode_within_node(child_a_2);
148
149 run_a_1 = run_a_1 && is_node_i_in_left_a;
150 run_a_2 = run_a_2 && is_node_i_in_right_a;
151
152 if (run_a_1 && run_b_1)
153 stack.push({child_a_1, child_b_1});
154 if (run_a_2 && run_b_1)
155 stack.push({child_a_2, child_b_1});
156 if (run_a_1 && run_b_2)
157 stack.push({child_a_1, child_b_2});
158 if (run_a_2 && run_b_2)
159 stack.push({child_a_2, child_b_2});
160
161 } else {
162
163 u32 child_a_1 = ttrav.get_left_child(a);
164 u32 child_a_2 = ttrav.get_right_child(a);
165 u32 child_b_1 = ttrav.get_left_child(b);
166 u32 child_b_2 = ttrav.get_right_child(b);
167
168 bool child_a_1_leaf = ttrav.is_id_leaf(child_a_1);
169 bool child_a_2_leaf = ttrav.is_id_leaf(child_a_2);
170 bool child_b_1_leaf = ttrav.is_id_leaf(child_b_1);
171 bool child_b_2_leaf = ttrav.is_id_leaf(child_b_2);
172
// Without leaf lowering the descent stops as soon as any of the four children is a leaf:
// the pair (a, b) itself is then counted as p2p and no child pair is pushed.
173 if ((child_a_1_leaf || child_a_2_leaf || child_b_1_leaf
174 || child_b_2_leaf)) {
175 if (is_a_i_same) {
176 count_p2p_i++; // found leaf-leaf interaction so skip child
177 // enqueue
178 }
179 continue;
180 }
181
182 bool is_node_i_in_left_a = is_kdnode_within_node(child_a_1);
183 bool is_node_i_in_right_a = is_kdnode_within_node(child_a_2);
184
// a is already cell i: this work-item does not descend further for this pair.
185 if (is_a_i_same) {
186 continue;
187 }
188
189 if (is_node_i_in_left_a) {
190 stack.push({child_a_1, child_b_1});
191 stack.push({child_a_1, child_b_2});
192 }
193 if (is_node_i_in_right_a) {
194 stack.push({child_a_2, child_b_1});
195 stack.push({child_a_2, child_b_2});
196 }
197 }
198
199 } else {
// mac() accepted the pair: it is an m2l interaction, counted only when a is cell i.
200 if (is_a_i_same) {
201 count_m2l_i++;
202 }
203 }
204 }
205
206 count_m2l[i] = count_m2l_i;
207 count_p2p[i] = count_p2p_i;
208 });
209
211
212 // scans the counts
// The two lines below are the initializers of `scan_m2l` and `scan_p2p` (their declaration
// lines are not visible). The scan covers the extra trailing slot as well.
214 = shamalgs::numeric::scan_exclusive(dev_sched, count_m2l, total_cell_count + 1);
216 = shamalgs::numeric::scan_exclusive(dev_sched, count_p2p, total_cell_count + 1);
217
218 // alloc results buffers
// The last scan entry is the sum of all per-cell counts, i.e. the size of each result list.
219 u32 total_count_m2l = scan_m2l.get_val_at_idx(total_cell_count);
220 u32 total_count_p2p = scan_p2p.get_val_at_idx(total_cell_count);
221
222 sham::DeviceBuffer<u32_2> idx_m2l(total_count_m2l, dev_sched);
223 sham::DeviceBuffer<u32_2> idx_p2p(total_count_p2p, dev_sched);
224
225 // relaunch the previous kernel but write the indexes this time
226
// Second pass: the traversal below has to stay identical to the counting pass, otherwise the
// number of writes per cell would no longer match the offsets derived from the counts.
228 q,
229 sham::MultiRef{obj_it, scan_m2l, scan_p2p},
230 sham::MultiRef{idx_m2l, idx_p2p},
231 total_cell_count,
232 [theta_crit](
233 u32 i,
234 ObjItAcc obj_it,
235 const u32 *__restrict__ scan_m2l,
236 const u32 *__restrict__ scan_p2p,
237 u32_2 *__restrict__ idx_m2l,
238 u32_2 *__restrict__ idx_p2p) {
// Write cursors of cell i; each one starts at the cell's scan offset and advances per write.
239 u32 offset_m2l = scan_m2l[i];
240 u32 offset_p2p = scan_p2p[i];
241
// Bounding box of cell i (initializer of `aabb_i`, whose declaration line is not visible).
243 = {obj_it.tree_traverser.aabb_min[i], obj_it.tree_traverser.aabb_max[i]};
244
// Same containment test as in the counting pass.
245 auto is_kdnode_within_node = [&](u32 node_id) -> bool {
246 shammath::AABB<Tvec> aabb_node
247 = {obj_it.tree_traverser.aabb_min[node_id],
248 obj_it.tree_traverser.aabb_max[node_id]};
249
250 return aabb_node.contains(aabb_i);
251 };
252
// `stack` is declared on a line not visible here.
254
255 auto &ttrav = obj_it.tree_traverser.tree_traverser;
256
257 // Am I a leaf before we start going down the tree ?
258 if (i == 0 && ttrav.is_id_leaf(0)) {
259 idx_p2p[offset_p2p] = {0, 0};
260 offset_p2p++;
261 } else {
262 // push root-root interact on stack
263 // We make the assumption that the root is not a leaf
264 stack.push({0, 0});
265 }
266
267 while (stack.is_not_empty()) {
268 u32_2 t = stack.pop_ret();
269 u32 a = t.x();
270 u32 b = t.y();
271
272 bool is_a_i_same = a == i;
273
274 shammath::AABB<Tvec> aabb_a = {
275 obj_it.tree_traverser.aabb_min[a], obj_it.tree_traverser.aabb_max[a]};
276 shammath::AABB<Tvec> aabb_b = {
277 obj_it.tree_traverser.aabb_min[b], obj_it.tree_traverser.aabb_max[b]};
278
// crit == true means mac() rejected the pair, so it has to be opened.
279 bool crit = mac(aabb_a, aabb_b, theta_crit) == false;
280
281 if (crit) {
282
283 if constexpr (allow_leaf_lowering) {
284
285 bool is_a_leaf = ttrav.is_id_leaf(a);
286 bool is_b_leaf = ttrav.is_id_leaf(b);
287
// Leaf-leaf pair: emitted as a p2p entry where the first pass only counted it.
288 if (is_a_leaf && is_b_leaf) {
289 if (is_a_i_same) {
290 idx_p2p[offset_p2p] = {a, b};
291 offset_p2p++;
292 }
293 continue;
294 }
295
296 u32 child_a_1 = (is_a_leaf) ? a : ttrav.get_left_child(a);
297 u32 child_a_2 = (is_a_leaf) ? a : ttrav.get_right_child(a);
298 u32 child_b_1 = (is_b_leaf) ? b : ttrav.get_left_child(b);
299 u32 child_b_2 = (is_b_leaf) ? b : ttrav.get_right_child(b);
300
301 bool run_a_1 = true;
302 bool run_a_2 = !is_a_leaf;
303 bool run_b_1 = true;
304 bool run_b_2 = !is_b_leaf;
305
306 // now since we can re-enqueue the same node we need to escape only
307 // if the child is enqueued so we replace the is_a_i_same condition
308 // from the case without lowering by is_a_i_same && (child_a_1 != a)
309 if (is_a_i_same && (child_a_1 != a)) {
310 continue;
311 }
312
313 bool is_node_i_in_left_a = is_kdnode_within_node(child_a_1);
314 bool is_node_i_in_right_a = is_kdnode_within_node(child_a_2);
315
316 run_a_1 = run_a_1 && is_node_i_in_left_a;
317 run_a_2 = run_a_2 && is_node_i_in_right_a;
318
319 if (run_a_1 && run_b_1)
320 stack.push({child_a_1, child_b_1});
321 if (run_a_2 && run_b_1)
322 stack.push({child_a_2, child_b_1});
323 if (run_a_1 && run_b_2)
324 stack.push({child_a_1, child_b_2});
325 if (run_a_2 && run_b_2)
326 stack.push({child_a_2, child_b_2});
327
328 } else {
329 u32 child_a_1 = ttrav.get_left_child(a);
330 u32 child_a_2 = ttrav.get_right_child(a);
331 u32 child_b_1 = ttrav.get_left_child(b);
332 u32 child_b_2 = ttrav.get_right_child(b);
333
334 bool child_a_1_leaf = ttrav.is_id_leaf(child_a_1);
335 bool child_a_2_leaf = ttrav.is_id_leaf(child_a_2);
336 bool child_b_1_leaf = ttrav.is_id_leaf(child_b_1);
337 bool child_b_2_leaf = ttrav.is_id_leaf(child_b_2);
338
// Any child being a leaf ends the descent; (a, b) itself is emitted as the p2p entry.
339 if ((child_a_1_leaf || child_a_2_leaf || child_b_1_leaf
340 || child_b_2_leaf)) {
341 if (is_a_i_same) {
342 idx_p2p[offset_p2p] = {a, b};
343 offset_p2p++;
344 }
345 continue;
346 }
347
348 bool is_node_i_in_left_a = is_kdnode_within_node(child_a_1);
349 bool is_node_i_in_right_a = is_kdnode_within_node(child_a_2);
350
351 if (is_a_i_same) {
352 continue;
353 }
354
355 if (is_node_i_in_left_a) {
356 stack.push({child_a_1, child_b_1});
357 stack.push({child_a_1, child_b_2});
358 }
359 if (is_node_i_in_right_a) {
360 stack.push({child_a_2, child_b_1});
361 stack.push({child_a_2, child_b_2});
362 }
363 }
364
365 } else {
// mac() accepted the pair: emit it as an m2l entry when a is cell i.
366 if (is_a_i_same) {
367 idx_m2l[offset_m2l] = {a, b};
368 offset_m2l++;
369 }
370 }
371 }
372 });
373
374 DTTResult ret{std::move(idx_m2l), std::move(idx_p2p)};
375
// The scan buffers are handed over to the result only on request; otherwise they are simply
// dropped when this function returns.
376 if (ordered_result) {
377 DTTResult::OrderedResult ordering{std::move(scan_m2l), std::move(scan_p2p)};
378 ret.ordered_result = std::move(ordering);
379 }
380
381 return ret;
382 }
383
/// @brief Runtime entry point: turns the `allow_leaf_lowering` flag into the matching
///        dtt_internal<true> / dtt_internal<false> instantiation.
///
/// All other arguments are forwarded unchanged and the DTTResult of dtt_internal is returned.
///
/// @param dev_sched device scheduler forwarded to dtt_internal.
/// @param theta_crit criterion threshold forwarded to dtt_internal.
/// @param ordered_result forwarded to dtt_internal.
/// @param allow_leaf_lowering selects which of the two instantiations is called.
/// @return the DTTResult produced by the selected dtt_internal instantiation.
///
/// NOTE(review): `bvh` is forwarded below, but the line declaring that parameter is absent
/// from this listing.
384 inline static shamtree::DTTResult dtt(
385 sham::DeviceScheduler_ptr dev_sched,
387 shambase::VecComponent<Tvec> theta_crit,
388 bool ordered_result,
389 bool allow_leaf_lowering) {
390 if (allow_leaf_lowering) {
391 return dtt_internal<true>(dev_sched, bvh, theta_crit, ordered_result);
392 } else {
393 return dtt_internal<false>(dev_sched, bvh, theta_crit, ordered_result);
394 }
395 }
396 };
397
398} // namespace shamtree::details
Dual tree traversal algorithm for Compressed Leaf Bounding Volume Hierarchies.
Fixed-size stack container for high-performance applications.
std::uint32_t u32
32 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
T get_val_at_idx(size_t idx) const
Get the value at a given index in the buffer.
A Compressed Leaf Bounding Volume Hierarchy (CLBVH) for neighborhood queries.
shamtree::CLBVHObjectIterator< Tmorton, Tvec, dim > get_object_iterator() const
Retrieves an iterator for traversing objects in the BVH.
KarrasRadixTree structure
The tree structure.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
sycl::buffer< T > scan_exclusive(sycl::queue &q, sycl::buffer< T > &buf1, u32 len)
Computes the exclusive sum of elements in a SYCL buffer.
Definition numeric.cpp:35
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the std::unique_ptr does not hold an object.
Definition memory.hpp:110
A class that references multiple buffers or similar objects.
Fixed-capacity stack container with compile-time size determination.
void push(const T &val)
Push an element onto the top of the stack.
T pop_ret()
Remove and return the top element from the stack.
constexpr bool is_not_empty() const
Check if the stack contains any elements.
Axis-Aligned bounding box.
Definition AABB.hpp:99
bool contains(AABB other) const noexcept
Check if AABB fully contains another AABB.
Definition AABB.hpp:244
This class is designed to traverse a BVH tree represented as a Compressed Leaf BVH (CLBVH) and a Karras radix tree.
Result structure for dual tree traversal operations.