Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
dtt_scan_multipass.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
18#include "shambase/memory.hpp"
22#include "shambackends/vec.hpp"
23#include "shamcomm/logs.hpp"
24#include "shammath/AABB.hpp"
29
30namespace shamtree::details {
31
32 template<class Tmorton, class Tvec, u32 dim>
34
35 using Tscal = shambase::VecComponent<Tvec>;
36
          /// Multipole acceptance criterion (MAC) for a pair of cell bounding boxes.
          ///
          /// Thin wrapper forwarding to the namespace-scope
          /// shamtree::details::mac overload (declared in an included header,
          /// not visible in this chunk). Based on how the result is used in
          /// dtt_internal: a `true` return accepts the pair (a, b) for a
          /// far-field M2L interaction, `false` forces refinement / P2P.
          /// Exact separation convention encoded by theta_crit is defined by
          /// the forwarded overload -- TODO confirm against its declaration.
          ///
          /// @param a bounding box of the first cell
          /// @param b bounding box of the second cell
          /// @param theta_crit critical opening angle of the criterion
          /// @return true when the pair is accepted for multipole interaction
37        inline static bool mac(shammath::AABB<Tvec> a, shammath::AABB<Tvec> b, Tscal theta_crit) {
38            return shamtree::details::mac(a, b, theta_crit);
39        }
40
          /// Multipass (breadth-first) dual tree traversal of a CLBVH.
          ///
          /// Starting from the root-root pair (0, 0), each pass classifies every
          /// pending cell pair (a, b):
          ///   - pair passes the MAC  -> emitted as a far-field M2L interaction,
          ///   - pair fails the MAC   -> either emitted as a near-field P2P
          ///     interaction (leaf cases) or split into up to four child pairs
          ///     that form the next pass's task list.
          /// Each pass then packs the surviving tasks and the emitted pairs via
          /// exclusive prefix sums + stream compaction. The loop ends when no
          /// task pair remains.
          ///
          /// @tparam allow_leaf_lowering when true, a pair where only one side
          ///         is a leaf keeps the leaf fixed and refines the other side
          ///         ("lowers" the leaf against the partner's children);
          ///         when false, refinement stops (P2P) as soon as any child of
          ///         the pair is a leaf.
          /// @param dev_sched device scheduler providing the queue and buffers
          /// @param theta_crit opening-angle parameter forwarded to mac()
          /// @param ordered_result when true, additionally compute per-cell
          ///        offset tables so the interaction lists can be consumed
          ///        grouped by cell (DTTResult::OrderedResult)
          /// @return DTTResult holding the M2L and P2P pair lists
          ///
          /// NOTE(review): the `bvh` parameter (the traversed CLBVH) sits on a
          /// line elided by the documentation extraction (Doxygen line 44);
          /// the body uses it throughout.
41        template<bool allow_leaf_lowering>
42        inline static shamtree::DTTResult dtt_internal(
43            sham::DeviceScheduler_ptr dev_sched,
45            shambase::VecComponent<Tvec> theta_crit,
46            bool ordered_result) {
47            StackEntry stack_loc{};
48
49            auto q = shambase::get_check_ref(dev_sched).get_queue();
50
              // Result pair lists; start empty and are expanded once per pass.
51            sham::DeviceBuffer<u32_2> node_interactions_m2l(0, dev_sched);
52            sham::DeviceBuffer<u32_2> node_interactions_p2p(0, dev_sched);
53
              // (result aggregate declaration on elided Doxygen line 54)
55                std::move(node_interactions_m2l), std::move(node_interactions_p2p)};
56
              // Optional post-pass: compute per-cell offset tables over the
              // final interaction lists so consumers can iterate them grouped
              // by cell id. Invoked on both the early-out and normal paths.
57            auto add_ordering = [&]() {
58                if (ordered_result) {
59                    auto offset_m2l = sham::DeviceBuffer<u32>(0, dev_sched);
60                    auto offset_p2p = sham::DeviceBuffer<u32>(0, dev_sched);
61
62                    shamtree::details::reorder_scan_dtt_result(
63                        bvh.structure.get_total_cell_count(),
64                        result.node_interactions_m2l,
65                        offset_m2l);
66
67                    shamtree::details::reorder_scan_dtt_result(
68                        bvh.structure.get_total_cell_count(),
69                        result.node_interactions_p2p,
70                        offset_p2p);
71
72                    DTTResult::OrderedResult ordering{std::move(offset_m2l), std::move(offset_p2p)};
73
74                    result.ordered_result = std::move(ordering);
75                }
76            };
77
              // Upper bound on valid cell indices, used only by SHAM_ASSERT below.
78            u32 max_cell_idx = bvh.structure.get_total_cell_count();
79
              // Degenerate tree: the root itself is a leaf, so the only
              // possible interaction is the root with itself, as P2P.
80            if (bvh.is_root_leaf()) {
81                result.node_interactions_p2p.resize(1);
82                result.node_interactions_p2p.set_val_at_idx(0, {0, 0});
83                add_ordering();
84                return result;
85            }
86
87            // we assume from this point that the root is not a leaf
88
              // Traversal frontier, seeded with the root-root pair.
89            sham::DeviceBuffer<u32_2> task_current(1, dev_sched);
90            task_current.set_val_at_idx(0, {0, 0});
91
92            u32 start_size = 1; // bvh.structure.get_total_cell_count();
93
              // Per-task scratch buffers, grown lazily each pass:
              //  - has_pushed_* : per-task counters, later scanned in place to
              //    become exclusive-prefix-sum write offsets,
              //  - task_next    : up to 4 child pairs per task (slot i*4 .. i*4+3),
              //  - pushed_*     : one candidate emitted pair per task.
94            sham::DeviceBuffer<u32> has_pushed_task(start_size, dev_sched);
95            sham::DeviceBuffer<u32_2> task_next(start_size, dev_sched);
96
97            sham::DeviceBuffer<u32> has_pushed_m2l(start_size, dev_sched);
98            sham::DeviceBuffer<u32_2> pushed_m2l(start_size, dev_sched);
99
100            sham::DeviceBuffer<u32> has_pushed_p2p(start_size, dev_sched);
101            sham::DeviceBuffer<u32_2> pushed_p2p(start_size, dev_sched);
102
               // Grow-only resize: avoids shrinking (and reallocating) between
               // passes when the task count fluctuates.
103            auto resize_max = [](auto &buf, u32 sz) {
104                if (buf.get_size() < sz) {
105                    buf.resize(sz);
106                }
107            };
108
               // One iteration per tree level-ish pass; stops when the frontier
               // is empty.
109            while (task_current.get_size() > 0) {
110                u32 task_count = task_current.get_size();
111                shamlog_debug_ln("dtt_scan_multipass", "task_current.get_size() :", task_count);

112
113                // resizing BS
                   // Counter buffers get one extra slot (task_count + 1) so the
                   // exclusive scan's last entry holds the pass total.
114                u32 has_pushed_task_sz = task_count + 1;
115                u32 task_next_sz = 4 * task_count;
116                u32 has_pushed_m2l_sz = task_count + 1;
117                u32 pushed_m2l_sz = task_count;
118                u32 has_pushed_p2p_sz = task_count + 1;
119                u32 pushed_p2p_sz = task_count;
120
121                resize_max(has_pushed_task, has_pushed_task_sz);
122                resize_max(task_next, task_next_sz);
123                resize_max(has_pushed_m2l, has_pushed_m2l_sz);
124                resize_max(pushed_m2l, pushed_m2l_sz);
125                resize_max(has_pushed_p2p, has_pushed_p2p_sz);
126                resize_max(pushed_p2p, pushed_p2p_sz);
127
                   // Counters must be zeroed each pass (the classify kernel only
                   // writes slots of tasks that actually push something).
128                has_pushed_task.fill(0, has_pushed_task_sz);
129                has_pushed_m2l.fill(0, has_pushed_m2l_sz);
130                has_pushed_p2p.fill(0, has_pushed_p2p_sz);
131
                   // (ObjectIterator alias on elided Doxygen line 132)
133                ObjectIterator obj_it = bvh.get_object_iterator();
134
135                using ObjItAcc = typename ObjectIterator::acc;
136
137                // the embarrassingly parallel bit
                   // Classify kernel: one work item per pending pair.
                   // (sham::kernel_call( head on elided Doxygen line 138)
139                    q,
140                    sham::MultiRef{task_current, obj_it},
142                        has_pushed_task,
143                        task_next,
144                        has_pushed_m2l,
145                        pushed_m2l,
146                        has_pushed_p2p,
147                        pushed_p2p},
148                    task_count,
149                    [theta_crit](
150                        u32 i,
151                        const u32_2 *__restrict__ task_current,
152                        ObjItAcc obj_it,
153                        u32 *__restrict__ has_pushed_task,
154                        u32_2 *__restrict__ task_next,
155                        u32 *__restrict__ has_pushed_m2l,
156                        u32_2 *__restrict__ pushed_m2l,
157                        u32 *__restrict__ has_pushed_p2p,
158                        u32_2 *__restrict__ pushed_p2p) {
                           // Unpack the pair of cell ids for this task.
159                        u32_2 t = task_current[i];
160                        u32 a = t.x();
161                        u32 b = t.y();
162
163                        shammath::AABB<Tvec> aabb_a = {
164                            obj_it.tree_traverser.aabb_min[a], obj_it.tree_traverser.aabb_max[a]};
165                        shammath::AABB<Tvec> aabb_b = {
166                            obj_it.tree_traverser.aabb_min[b], obj_it.tree_traverser.aabb_max[b]};
167
                           // crit == true  -> MAC rejected the pair: refine or go P2P.
                           // crit == false -> MAC accepted the pair: far-field M2L.
168                        bool crit = mac(aabb_a, aabb_b, theta_crit) == false;
169
170                        if (crit) {
171                            auto &ttrav = obj_it.tree_traverser.tree_traverser;
172
173                            if constexpr (allow_leaf_lowering) {
174                                bool is_a_leaf = ttrav.is_id_leaf(a);
175                                bool is_b_leaf = ttrav.is_id_leaf(b);
176
                                   // Both leaves: cannot refine further -> near field.
177                                if (is_a_leaf && is_b_leaf) {
178                                    pushed_p2p[i] = {a, b};
179                                    has_pushed_p2p[i] = 1;
180                                } else {
181
                                       // A leaf side is "lowered": kept as-is while
                                       // the non-leaf side is replaced by its children.
182                                    u32 child_a_1 = (is_a_leaf) ? a : ttrav.get_left_child(a);
183                                    u32 child_a_2 = (is_a_leaf) ? a : ttrav.get_right_child(a);
184                                    u32 child_b_1 = (is_b_leaf) ? b : ttrav.get_left_child(b);
185                                    u32 child_b_2 = (is_b_leaf) ? b : ttrav.get_right_child(b);
186
                                       // Skip the duplicate second slot of a leaf side
                                       // (child_*_1 == child_*_2 == the leaf itself).
187                                    bool run_a_1 = true;
188                                    bool run_a_2 = !is_a_leaf;
189                                    bool run_b_1 = true;
190                                    bool run_b_2 = !is_b_leaf;
191
192                                    u32 push_count = 0;
193
                                       // Emit the 1, 2 or 4 distinct child pairs
                                       // contiguously in this task's 4-wide slot.
194                                    if (run_a_1 && run_b_1) {
195                                        task_next[i * 4 + push_count] = {child_a_1, child_b_1};
196                                        push_count++;
197                                    }
198                                    if (run_a_2 && run_b_1) {
199                                        task_next[i * 4 + push_count] = {child_a_2, child_b_1};
200                                        push_count++;
201                                    }
202                                    if (run_a_1 && run_b_2) {
203                                        task_next[i * 4 + push_count] = {child_a_1, child_b_2};
204                                        push_count++;
205                                    }
206                                    if (run_a_2 && run_b_2) {
207                                        task_next[i * 4 + push_count] = {child_a_2, child_b_2};
208                                        push_count++;
209                                    }
210                                    has_pushed_task[i] += push_count;
211                                }
212
213                            } else {
                                   // Non-lowering strategy: root is known non-leaf
                                   // (and this branch only ever enqueues pairs of
                                   // non-leaf cells), so children always exist here.
214                                u32 child_a_1 = ttrav.get_left_child(a);
215                                u32 child_a_2 = ttrav.get_right_child(a);
216                                u32 child_b_1 = ttrav.get_left_child(b);
217                                u32 child_b_2 = ttrav.get_right_child(b);
218
219                                bool child_a_1_leaf = ttrav.is_id_leaf(child_a_1);
220                                bool child_a_2_leaf = ttrav.is_id_leaf(child_a_2);
221                                bool child_b_1_leaf = ttrav.is_id_leaf(child_b_1);
222                                bool child_b_2_leaf = ttrav.is_id_leaf(child_b_2);
223
                                   // Stop refining as soon as any child is a leaf:
                                   // the parent pair itself becomes a P2P interaction.
224                                if (child_a_1_leaf || child_a_2_leaf || child_b_1_leaf
225                                    || child_b_2_leaf) {
226                                    pushed_p2p[i] = {a, b};
227                                    has_pushed_p2p[i] = 1;
228                                } else {
                                       // Otherwise always push the full 2x2 child cross
                                       // product into the next pass.
229                                    task_next[i * 4 + 0] = {child_a_1, child_b_1};
230                                    task_next[i * 4 + 1] = {child_a_1, child_b_2};
231                                    task_next[i * 4 + 2] = {child_a_2, child_b_1};
232                                    task_next[i * 4 + 3] = {child_a_2, child_b_2};
233                                    has_pushed_task[i] += 4;
234                                }
235                            }
236
237                        } else {
                               // MAC accepted: far-field node-node (M2L) interaction.
238                            pushed_m2l[i] = {a, b};
239                            has_pushed_m2l[i] = 1;
240                        }
241                    });
242
                   // Turn per-task counters into exclusive prefix sums = write
                   // offsets for compaction; the last slot is the pass total.
243// set to false to use standard scans instead of in place ones
244#if true
                   // (scan_exclusive_sum_in_place( call heads on elided Doxygen
                   //  lines 245 / 247 / 249)
246                    has_pushed_task, has_pushed_task_sz);
248                    has_pushed_m2l, has_pushed_m2l_sz);
250                    has_pushed_p2p, has_pushed_p2p_sz);
251
252#else
253                has_pushed_task = shamalgs::numeric::scan_exclusive(
254                    dev_sched, has_pushed_task, has_pushed_task_sz);
255                has_pushed_m2l = shamalgs::numeric::scan_exclusive(
256                    dev_sched, has_pushed_m2l, has_pushed_m2l_sz);
257                has_pushed_p2p = shamalgs::numeric::scan_exclusive(
258                    dev_sched, has_pushed_p2p, has_pushed_p2p_sz);
259#endif
                   // Aliases documenting that the counters now hold scan results.
260                sham::DeviceBuffer<u32> &scan_task = (has_pushed_task);
261                sham::DeviceBuffer<u32> &scan_m2l = (has_pushed_m2l);
262                sham::DeviceBuffer<u32> &scan_p2p = (has_pushed_p2p);
263
264                // get the sizes of the result buffers before resizing
265                u32 res_sz_node_node = result.node_interactions_m2l.get_size();
266                u32 res_sz_leaf_leaf = result.node_interactions_p2p.get_size();
267
268                // get the resulting count from the main kernel
269                u32 count_task = scan_task.get_val_at_idx(has_pushed_task_sz - 1);
270                u32 count_m2l = scan_m2l.get_val_at_idx(has_pushed_m2l_sz - 1);
271                u32 count_p2p = scan_p2p.get_val_at_idx(has_pushed_p2p_sz - 1);
272
273                // expand the result buffers
274                result.node_interactions_m2l.expand(count_m2l);
275                result.node_interactions_p2p.expand(count_p2p);
276
277                // allocate space for the next pass
278                task_current.resize(count_task);
279
280                // 4 wide stream compaction
                   // Pack the sparse 4-wide task_next slots into the dense
                   // task_current frontier using the scan offsets.
                   // (sham::kernel_call( head on elided Doxygen line 281)
282                    q,
283                    sham::MultiRef{task_next, scan_task},
284                    sham::MultiRef{task_current},
285                    task_count,
286                    [max_cell_idx](
287                        u32 i,
288                        const u32_2 *__restrict__ task_next,
289                        const u32 *__restrict__ scan_task,
290                        u32_2 *__restrict__ task_current) {
                           // delta = number of pairs task i pushed (0..4).
291                        u32 scan_task_i = scan_task[i];
292                        u32 scan_task_ip1 = scan_task[i + 1];
293                        u32 delta = scan_task_ip1 - scan_task_i;
294                        if (delta > 0) {
295                            u32 idx = scan_task_i;
296
297                            if constexpr (allow_leaf_lowering) {
                                   // Variable push count: copy exactly delta slots,
                                   // asserting the ids are valid cell indices.
298                                for (u32 l = 0; l < delta; l++) {
299                                    SHAM_ASSERT(task_next[i * 4 + l].x() < max_cell_idx);
300                                    SHAM_ASSERT(task_next[i * 4 + l].y() < max_cell_idx);
301                                    task_current[idx + l] = task_next[i * 4 + l];
302                                }
303                            } else {
                                   // Fixed push count: this strategy always pushes
                                   // 0 or 4 pairs, so unroll the copy.
304                                task_current[idx + 0] = task_next[i * 4 + 0];
305                                task_current[idx + 1] = task_next[i * 4 + 1];
306                                task_current[idx + 2] = task_next[i * 4 + 2];
307                                task_current[idx + 3] = task_next[i * 4 + 3];
308                            }
309                        }
310                    });
311
312                // stream compaction
                   // Append this pass's M2L pairs after the res_sz_node_node
                   // entries accumulated by previous passes.
                   // (sham::kernel_call( head on elided Doxygen line 313)
314                    q,
315                    sham::MultiRef{pushed_m2l, scan_m2l},
316                    sham::MultiRef{result.node_interactions_m2l},
317                    task_count,
318                    [res_sz_node_node](
319                        u32 i,
320                        const u32_2 *__restrict__ pushed_m2l,
321                        const u32 *__restrict__ scan_m2l,
322                        u32_2 *__restrict__ interacts_m2l) {
323                        u32 scan_m2l_i = scan_m2l[i];
324                        u32 scan_m2l_ip1 = scan_m2l[i + 1];
325                        if (scan_m2l_ip1 - scan_m2l_i == 1) {
326                            interacts_m2l[res_sz_node_node + scan_m2l_i] = pushed_m2l[i];
327                        }
328                    });
329
330                // stream compaction
                   // Same append scheme for this pass's P2P pairs.
                   // (sham::kernel_call( head on elided Doxygen line 331)
332                    q,
333                    sham::MultiRef{pushed_p2p, scan_p2p},
334                    sham::MultiRef{result.node_interactions_p2p},
335                    task_count,
336                    [res_sz_leaf_leaf](
337                        u32 i,
338                        const u32_2 *__restrict__ pushed_p2p,
339                        const u32 *__restrict__ scan_p2p,
340                        u32_2 *__restrict__ interact_p2p) {
341                        u32 scan_p2p_i = scan_p2p[i];
342                        u32 scan_p2p_ip1 = scan_p2p[i + 1];
343                        if (scan_p2p_ip1 - scan_p2p_i == 1) {
344                            interact_p2p[res_sz_leaf_leaf + scan_p2p_i] = pushed_p2p[i];
345                        }
346                    });
347            }
348
349            add_ordering();
350
351            return result;
352        }
353
          /// Public dual-tree-traversal entry point.
          ///
          /// Dispatches the runtime allow_leaf_lowering flag onto the
          /// compile-time template parameter of dtt_internal, so the per-pair
          /// device kernel can select its refinement strategy with
          /// `if constexpr` instead of a runtime branch.
          ///
          /// @param dev_sched device scheduler providing the queue and buffers
          /// @param theta_crit opening-angle parameter forwarded to mac()
          /// @param ordered_result request per-cell ordering offsets in the result
          /// @param allow_leaf_lowering select the leaf-lowering refinement
          ///        strategy (see dtt_internal)
          /// @return DTTResult with the M2L / P2P interaction pair lists
          ///
          /// NOTE(review): the `bvh` parameter (the traversed CLBVH, forwarded
          /// below) sits on a line elided by the documentation extraction
          /// (Doxygen line 356).
354        inline static shamtree::DTTResult dtt(
355            sham::DeviceScheduler_ptr dev_sched,
357            shambase::VecComponent<Tvec> theta_crit,
358            bool ordered_result,
359            bool allow_leaf_lowering) {
360            if (allow_leaf_lowering) {
361                return dtt_internal<true>(dev_sched, bvh, theta_crit, ordered_result);
362            } else {
363                return dtt_internal<false>(dev_sched, bvh, theta_crit, ordered_result);
364            }
365        }
366 };
367} // namespace shamtree::details
Dual tree traversal algorithm for Compressed Leaf Bounding Volume Hierarchies.
std::uint32_t u32
32 bit unsigned integer
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
Definition assert.hpp:67
A buffer allocated in USM (Unified Shared Memory)
void fill(T value, std::array< size_t, 2 > idx_range)
Fill a subpart of the buffer with a given value.
T get_val_at_idx(size_t idx) const
Get the value at a given index in the buffer.
size_t get_size() const
Gets the number of elements in the buffer.
A Compressed Leaf Bounding Volume Hierarchy (CLBVH) for neighborhood queries.
shamtree::CLBVHObjectIterator< Tmorton, Tvec, dim > get_object_iterator() const
Retrieves an iterator for traversing objects in the BVH.
KarrasRadixTree structure
The tree structure.
bool is_root_leaf() const
is the root a leaf ?
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
sycl::buffer< T > scan_exclusive(sycl::queue &q, sycl::buffer< T > &buf1, u32 len)
Computes the exclusive sum of elements in a SYCL buffer.
Definition numeric.cpp:35
void scan_exclusive_sum_in_place(sham::DeviceBuffer< T > &buf1, u32 len)
Compute exclusive prefix sum in-place on a device buffer.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the pointer is null.
Definition memory.hpp:110
In-place exclusive scan (prefix sum) algorithm for device buffers.
A class that references multiple buffers or similar objects.
Axis-Aligned bounding box.
Definition AABB.hpp:99
This class is designed to traverse a BVH tree represented as a Compressed Leaf BVH (CLBVH) and a Karr...
Result structure for dual tree traversal operations.