Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
dtt_scan_multipass.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
17
18#include "shambase/memory.hpp"
22#include "shambackends/vec.hpp"
23#include "shamcomm/logs.hpp"
24#include "shammath/AABB.hpp"
29
30namespace shamtree::details {
31
// Multipass, scan-based dual tree traversal (DTT) of a BVH against itself.
// Each pass evaluates every pending (a, b) node pair in parallel, then uses
// exclusive scans + stream compaction to build the next task list and to
// append the accepted M2L / P2P interaction pairs to the result.
32 template<class Tmorton, class Tvec, u32 dim>
34
// Component (scalar) type of Tvec; used for the MAC parameter theta_crit.
35 using Tscal = shambase::VecComponent<Tvec>;
36
37 inline static bool mac(shammath::AABB<Tvec> a, shammath::AABB<Tvec> b, Tscal theta_crit) {
38 return shamtree::details::mac(a, b, theta_crit);
39 }
40
// dtt_internal: dual tree traversal of `bvh` against itself, performed as a
// sequence of data-parallel passes over a list of (a, b) node-index pairs.
//
// Each pass:
//   1. evaluates the MAC for every pending pair (one kernel item per pair);
//      accepted pairs are flagged as M2L, rejected pairs are either refined
//      into up to 4 child pairs for the next pass or flagged as P2P;
//   2. turns the per-pair flags/counts into offsets with exclusive scans;
//   3. stream-compacts the new tasks into `task_current` and appends the
//      flagged M2L / P2P pairs at the end of the result buffers.
// The loop stops when a pass produces no new task.
//
// Template parameter:
//   allow_leaf_lowering - true: a leaf is paired with the children of its
//     non-leaf partner, and P2P pairs are always (leaf, leaf).
//     false: a rejected pair is emitted as P2P as soon as any child of either
//     node is a leaf, so P2P pairs are internal nodes (the only exception is
//     the single-leaf-root case handled up front).
//
// @param dev_sched      device scheduler; provides the kernel queue and is
//                       used to allocate every device buffer below
// @param bvh            BVH traversed against itself
// @param theta_crit     MAC threshold forwarded to mac()
// @param ordered_result when true, additionally build per-cell offset tables
//                       with reorder_scan_dtt_result (result.ordered_result)
// @return DTTResult whose node_interactions_m2l / node_interactions_p2p hold
//         the accepted (a, b) pairs, appended pass after pass.
41 template<bool allow_leaf_lowering>
42 inline static shamtree::DTTResult dtt_internal(
43 sham::DeviceScheduler_ptr dev_sched,
45 shambase::VecComponent<Tvec> theta_crit,
46 bool ordered_result) {
47 StackEntry stack_loc{};
48
49 auto q = shambase::get_check_ref(dev_sched).get_queue();
50
// Result lists start empty; every pass grows them with expand().
51 sham::DeviceBuffer<u32_2> node_interactions_m2l(0, dev_sched);
52 sham::DeviceBuffer<u32_2> node_interactions_p2p(0, dev_sched);
53
55 .node_interactions_m2l = std::move(node_interactions_m2l),
56 .node_interactions_p2p = std::move(node_interactions_p2p)};
57
// Optional post-processing: when ordered_result is set, compute the m2l and
// p2p offset tables over all cells and attach them to result.ordered_result.
58 auto add_ordering = [&]() {
59 if (ordered_result) {
60 auto offset_m2l = sham::DeviceBuffer<u32>(0, dev_sched);
61 auto offset_p2p = sham::DeviceBuffer<u32>(0, dev_sched);
62
63 shamtree::details::reorder_scan_dtt_result(
64 bvh.structure.get_total_cell_count(),
66 offset_m2l);
67
68 shamtree::details::reorder_scan_dtt_result(
69 bvh.structure.get_total_cell_count(),
71 offset_p2p);
72
74 .offset_m2l = std::move(offset_m2l), .offset_p2p = std::move(offset_p2p)};
75
76 result.ordered_result = std::move(ordering);
77 }
78 };
79
// Upper bound on valid cell ids; only read by the SHAM_ASSERT bound checks
// in the task compaction kernel.
80 u32 max_cell_idx = bvh.structure.get_total_cell_count();
81
// Degenerate tree: a leaf root means the only interaction is root/root P2P.
82 if (bvh.is_root_leaf()) {
84 result.node_interactions_p2p.set_val_at_idx(0, {0, 0});
85 add_ordering();
86 return result;
87 }
88
89 // we assume from this point that the root is not a leaf
90
// The traversal is seeded with the single task (root, root).
91 sham::DeviceBuffer<u32_2> task_current(1, dev_sched);
92 task_current.set_val_at_idx(0, {0, 0});
93
// Scratch buffers start tiny and are grown on demand by resize_max below.
94 u32 start_size = 1; // bvh.structure.get_total_cell_count();
95
96 sham::DeviceBuffer<u32> has_pushed_task(start_size, dev_sched);
97 sham::DeviceBuffer<u32_2> task_next(start_size, dev_sched);
98
99 sham::DeviceBuffer<u32> has_pushed_m2l(start_size, dev_sched);
100 sham::DeviceBuffer<u32_2> pushed_m2l(start_size, dev_sched);
101
102 sham::DeviceBuffer<u32> has_pushed_p2p(start_size, dev_sched);
103 sham::DeviceBuffer<u32_2> pushed_p2p(start_size, dev_sched);
104
// Grow-only resize: buffers never shrink between passes, which avoids
// reallocating when the task count decreases.
// NOTE(review): resize() defaults to keep_data=true, but these scratch buffers
// are refilled / overwritten before being read in each pass, so passing
// keep_data=false would likely avoid a useless copy — confirm.
105 auto resize_max = [](auto &buf, u32 sz) {
106 if (buf.get_size() < sz) {
107 buf.resize(sz);
108 }
109 };
110
// One iteration == one breadth-first pass over all pending (a, b) pairs.
111 while (task_current.get_size() > 0) {
112 u32 task_count = task_current.get_size();
113 shamlog_debug_ln("dtt_scan_multipass", "task_current.get_size() :", task_count);
114
// Flag/count buffers get one extra slot so that, after the exclusive scan,
// their last entry holds the total. task_next holds up to 4 child pairs per
// task (slots [4*i, 4*i+4)); pushed_* hold at most one pair per task.
115 // resizing BS
116 u32 has_pushed_task_sz = task_count + 1;
117 u32 task_next_sz = 4 * task_count;
118 u32 has_pushed_m2l_sz = task_count + 1;
119 u32 pushed_m2l_sz = task_count;
120 u32 has_pushed_p2p_sz = task_count + 1;
121 u32 pushed_p2p_sz = task_count;
122
123 resize_max(has_pushed_task, has_pushed_task_sz);
124 resize_max(task_next, task_next_sz);
125 resize_max(has_pushed_m2l, has_pushed_m2l_sz);
126 resize_max(pushed_m2l, pushed_m2l_sz);
127 resize_max(has_pushed_p2p, has_pushed_p2p_sz);
128 resize_max(pushed_p2p, pushed_p2p_sz);
129
// Reset counters: the kernel only writes the entries of tasks that push
// something, and the trailing slot must be 0 for the scan total.
130 has_pushed_task.fill(0, has_pushed_task_sz);
131 has_pushed_m2l.fill(0, has_pushed_m2l_sz);
132 has_pushed_p2p.fill(0, has_pushed_p2p_sz);
133
135 ObjectIterator obj_it = bvh.get_object_iterator();
136
137 using ObjItAcc = typename ObjectIterator::acc;
138
// Main kernel, one item per task i = (a, b):
//  - MAC accepted      -> pushed_m2l[i] = (a, b), has_pushed_m2l[i] = 1
//  - rejected, P2P     -> pushed_p2p[i] = (a, b), has_pushed_p2p[i] = 1
//  - rejected, refined -> child pairs in task_next[4*i ...],
//                         has_pushed_task[i] = number of child pairs
139 // the embarrassingly parallel bit
141 q,
142 sham::MultiRef{task_current, obj_it},
144 has_pushed_task,
145 task_next,
146 has_pushed_m2l,
147 pushed_m2l,
148 has_pushed_p2p,
149 pushed_p2p},
150 task_count,
151 [theta_crit](
152 u32 i,
153 const u32_2 *__restrict__ task_current,
154 ObjItAcc obj_it,
155 u32 *__restrict__ has_pushed_task,
156 u32_2 *__restrict__ task_next,
157 u32 *__restrict__ has_pushed_m2l,
158 u32_2 *__restrict__ pushed_m2l,
159 u32 *__restrict__ has_pushed_p2p,
160 u32_2 *__restrict__ pushed_p2p) {
161 u32_2 t = task_current[i];
162 u32 a = t.x();
163 u32 b = t.y();
164
165 shammath::AABB<Tvec> aabb_a = {
166 obj_it.tree_traverser.aabb_min[a], obj_it.tree_traverser.aabb_max[a]};
167 shammath::AABB<Tvec> aabb_b = {
168 obj_it.tree_traverser.aabb_min[b], obj_it.tree_traverser.aabb_max[b]};
169
// crit == true: the MAC rejects the pair, so it must be refined or
// resolved as P2P instead of being recorded as M2L.
170 bool crit = mac(aabb_a, aabb_b, theta_crit) == false;
171
172 if (crit) {
173 auto &ttrav = obj_it.tree_traverser.tree_traverser;
174
// Leaf lowering: descend only into the non-leaf side(s); a leaf is
// kept as-is and paired with its partner's children.
175 if constexpr (allow_leaf_lowering) {
176 bool is_a_leaf = ttrav.is_id_leaf(a);
177 bool is_b_leaf = ttrav.is_id_leaf(b);
178
179 if (is_a_leaf && is_b_leaf) {
180 pushed_p2p[i] = {a, b};
181 has_pushed_p2p[i] = 1;
182 } else {
183
184 u32 child_a_1 = (is_a_leaf) ? a : ttrav.get_left_child(a);
185 u32 child_a_2 = (is_a_leaf) ? a : ttrav.get_right_child(a);
186 u32 child_b_1 = (is_b_leaf) ? b : ttrav.get_left_child(b);
187 u32 child_b_2 = (is_b_leaf) ? b : ttrav.get_right_child(b);
188
// A leaf side only contributes itself once (run_*_2 disabled),
// giving 2 child pairs if one side is a leaf, 4 otherwise.
189 bool run_a_1 = true;
190 bool run_a_2 = !is_a_leaf;
191 bool run_b_1 = true;
192 bool run_b_2 = !is_b_leaf;
193
194 u32 push_count = 0;
195
// Pairs are packed at the start of this task's 4-slot window.
// NOTE(review): emission order here is (a1,b1),(a2,b1),(a1,b2),(a2,b2)
// whereas the non-lowering branch uses (a1,b1),(a1,b2),(a2,b1),(a2,b2);
// presumably harmless, but the unordered result order differs by mode.
196 if (run_a_1 && run_b_1) {
197 task_next[i * 4 + push_count] = {child_a_1, child_b_1};
198 push_count++;
199 }
200 if (run_a_2 && run_b_1) {
201 task_next[i * 4 + push_count] = {child_a_2, child_b_1};
202 push_count++;
203 }
204 if (run_a_1 && run_b_2) {
205 task_next[i * 4 + push_count] = {child_a_1, child_b_2};
206 push_count++;
207 }
208 if (run_a_2 && run_b_2) {
209 task_next[i * 4 + push_count] = {child_a_2, child_b_2};
210 push_count++;
211 }
212 has_pushed_task[i] += push_count;
213 }
214
215 } else {
// No leaf lowering: a and b are always internal here (the root is not
// a leaf and only all-internal child pairs are ever pushed). If any
// child is a leaf, (a, b) itself becomes a P2P pair.
216 u32 child_a_1 = ttrav.get_left_child(a);
217 u32 child_a_2 = ttrav.get_right_child(a);
218 u32 child_b_1 = ttrav.get_left_child(b);
219 u32 child_b_2 = ttrav.get_right_child(b);
220
221 bool child_a_1_leaf = ttrav.is_id_leaf(child_a_1);
222 bool child_a_2_leaf = ttrav.is_id_leaf(child_a_2);
223 bool child_b_1_leaf = ttrav.is_id_leaf(child_b_1);
224 bool child_b_2_leaf = ttrav.is_id_leaf(child_b_2);
225
226 if (child_a_1_leaf || child_a_2_leaf || child_b_1_leaf
227 || child_b_2_leaf) {
228 pushed_p2p[i] = {a, b};
229 has_pushed_p2p[i] = 1;
230 } else {
231 task_next[i * 4 + 0] = {child_a_1, child_b_1};
232 task_next[i * 4 + 1] = {child_a_1, child_b_2};
233 task_next[i * 4 + 2] = {child_a_2, child_b_1};
234 task_next[i * 4 + 3] = {child_a_2, child_b_2};
235 has_pushed_task[i] += 4;
236 }
237 }
238
239 } else {
// MAC accepted: record (a, b) as an M2L interaction.
240 pushed_m2l[i] = {a, b};
241 has_pushed_m2l[i] = 1;
242 }
243 });
244
// Exclusive scans over the per-task counts: entry i becomes the write offset
// of task i, and entry task_count (0 before the scan) becomes the total.
245// set to false to use standard scans instead of in place ones
246#if true
248 has_pushed_task, has_pushed_task_sz);
250 has_pushed_m2l, has_pushed_m2l_sz);
252 has_pushed_p2p, has_pushed_p2p_sz);
253
254#else
255 has_pushed_task = shamalgs::numeric::scan_exclusive(
256 dev_sched, has_pushed_task, has_pushed_task_sz);
257 has_pushed_m2l = shamalgs::numeric::scan_exclusive(
258 dev_sched, has_pushed_m2l, has_pushed_m2l_sz);
259 has_pushed_p2p = shamalgs::numeric::scan_exclusive(
260 dev_sched, has_pushed_p2p, has_pushed_p2p_sz);
261#endif
// Aliases only: the scans above operate in place on the has_pushed_* buffers.
262 sham::DeviceBuffer<u32> &scan_task = (has_pushed_task);
263 sham::DeviceBuffer<u32> &scan_m2l = (has_pushed_m2l);
264 sham::DeviceBuffer<u32> &scan_p2p = (has_pushed_p2p);
265
266 // get the sizes of the result buffers before resizing
267 u32 res_sz_node_node = result.node_interactions_m2l.get_size();
268 u32 res_sz_leaf_leaf = result.node_interactions_p2p.get_size();
269
270 // get the resulting count from the main kernel
271 u32 count_task = scan_task.get_val_at_idx(has_pushed_task_sz - 1);
272 u32 count_m2l = scan_m2l.get_val_at_idx(has_pushed_m2l_sz - 1);
273 u32 count_p2p = scan_p2p.get_val_at_idx(has_pushed_p2p_sz - 1);
274
275 // expand the result buffers
276 result.node_interactions_m2l.expand(count_m2l);
277 result.node_interactions_p2p.expand(count_p2p);
278
// Every slot in [0, count_task) is rewritten by the compaction kernel below,
// so the previous contents of task_current are not needed anymore.
// NOTE(review): resize() defaults to keep_data=true, so this likely copies
// stale data; task_current.resize(count_task, false) would avoid it — confirm.
279 // allocate space for the next pass
280 task_current.resize(count_task);
281
// Gather each task's child pairs (task_next[4*i ...], count = scan delta)
// into the dense next task list starting at offset scan_task[i].
282 // 4 wide stream compaction
284 q,
285 sham::MultiRef{task_next, scan_task},
286 sham::MultiRef{task_current},
287 task_count,
288 [max_cell_idx](
289 u32 i,
290 const u32_2 *__restrict__ task_next,
291 const u32 *__restrict__ scan_task,
292 u32_2 *__restrict__ task_current) {
293 u32 scan_task_i = scan_task[i];
294 u32 scan_task_ip1 = scan_task[i + 1];
295 u32 delta = scan_task_ip1 - scan_task_i;
296 if (delta > 0) {
297 u32 idx = scan_task_i;
298
// With leaf lowering a task pushes 2 or 4 pairs (variable count).
299 if constexpr (allow_leaf_lowering) {
300 for (u32 l = 0; l < delta; l++) {
301 SHAM_ASSERT(task_next[i * 4 + l].x() < max_cell_idx);
302 SHAM_ASSERT(task_next[i * 4 + l].y() < max_cell_idx);
303 task_current[idx + l] = task_next[i * 4 + l];
304 }
305 } else {
// Without leaf lowering a refined task always pushes exactly 4 pairs.
306 task_current[idx + 0] = task_next[i * 4 + 0];
307 task_current[idx + 1] = task_next[i * 4 + 1];
308 task_current[idx + 2] = task_next[i * 4 + 2];
309 task_current[idx + 3] = task_next[i * 4 + 3];
310 }
311 }
312 });
313
// Append this pass's M2L pairs after the res_sz_node_node pairs already
// present in the result (a scan delta of 1 marks a flagged task).
314 // stream compaction
316 q,
317 sham::MultiRef{pushed_m2l, scan_m2l},
319 task_count,
320 [res_sz_node_node](
321 u32 i,
322 const u32_2 *__restrict__ pushed_m2l,
323 const u32 *__restrict__ scan_m2l,
324 u32_2 *__restrict__ interacts_m2l) {
325 u32 scan_m2l_i = scan_m2l[i];
326 u32 scan_m2l_ip1 = scan_m2l[i + 1];
327 if (scan_m2l_ip1 - scan_m2l_i == 1) {
328 interacts_m2l[res_sz_node_node + scan_m2l_i] = pushed_m2l[i];
329 }
330 });
331
// Same compaction for this pass's P2P pairs, appended after the
// res_sz_leaf_leaf pairs already present.
332 // stream compaction
334 q,
335 sham::MultiRef{pushed_p2p, scan_p2p},
337 task_count,
338 [res_sz_leaf_leaf](
339 u32 i,
340 const u32_2 *__restrict__ pushed_p2p,
341 const u32 *__restrict__ scan_p2p,
342 u32_2 *__restrict__ interact_p2p) {
343 u32 scan_p2p_i = scan_p2p[i];
344 u32 scan_p2p_ip1 = scan_p2p[i + 1];
345 if (scan_p2p_ip1 - scan_p2p_i == 1) {
346 interact_p2p[res_sz_leaf_leaf + scan_p2p_i] = pushed_p2p[i];
347 }
348 });
349 }
350
// All passes done: optionally attach the per-cell ordering offsets.
351 add_ordering();
352
353 return result;
354 }
355
// Public entry point of the multipass dual tree traversal.
//
// @param dev_sched           device scheduler used for kernels and buffers
// @param bvh                 BVH traversed against itself
// @param theta_crit          MAC threshold
// @param ordered_result      also compute per-cell ordering offsets
// @param allow_leaf_lowering pair leaves with the children of non-leaf
//                            partners (true) or emit P2P at the parent level
//                            as soon as a child is a leaf (false)
// @return the DTTResult produced by dtt_internal
356 inline static shamtree::DTTResult dtt(
357 sham::DeviceScheduler_ptr dev_sched,
359 shambase::VecComponent<Tvec> theta_crit,
360 bool ordered_result,
361 bool allow_leaf_lowering) {
// Lift the runtime flag into the template parameter so the kernels can
// select their branch with `if constexpr`.
362 if (allow_leaf_lowering) {
363 return dtt_internal<true>(dev_sched, bvh, theta_crit, ordered_result);
364 } else {
365 return dtt_internal<false>(dev_sched, bvh, theta_crit, ordered_result);
366 }
367 }
368 };
369} // namespace shamtree::details
Dual tree traversal algorithm for Compressed Leaf Bounding Volume Hierarchies.
std::uint32_t u32
32-bit unsigned integer
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
Definition assert.hpp:67
A buffer allocated in USM (Unified Shared Memory).
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
void fill(T value, std::array< size_t, 2 > idx_range)
Fill a subpart of the buffer with a given value.
T get_val_at_idx(size_t idx) const
Get the value at a given index in the buffer.
size_t get_size() const
Gets the number of elements in the buffer.
void expand(u32 add_sz)
Expand the buffer by add_sz elements.
A Compressed Leaf Bounding Volume Hierarchy (CLBVH) for neighborhood queries.
shamtree::CLBVHObjectIterator< Tmorton, Tvec, dim > get_object_iterator() const
Retrieves an iterator for traversing objects in the BVH.
KarrasRadixTree structure
The tree structure.
bool is_root_leaf() const
Is the root a leaf?
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
sycl::buffer< T > scan_exclusive(sycl::queue &q, sycl::buffer< T > &buf1, u32 len)
Computes the exclusive sum of elements in a SYCL buffer.
Definition numeric.cpp:35
void scan_exclusive_sum_in_place(sham::DeviceBuffer< T > &buf1, u32 len)
Compute exclusive prefix sum in-place on a device buffer.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
In-place exclusive scan (prefix sum) algorithm for device buffers.
shambase::details::BasicStackEntry StackEntry
Alias for shambase::details::BasicStackEntry.
A class that references multiple buffers or similar objects.
Axis-Aligned bounding box.
Definition AABB.hpp:99
This class is designed to traverse a BVH tree represented as a Compressed Leaf BVH (CLBVH) and a Karr...
Result structure for dual tree traversal operations.
sham::DeviceBuffer< u32_2 > node_interactions_m2l
Pairs of nodes that interact using M2L interactions.
sham::DeviceBuffer< u32_2 > node_interactions_p2p
Pairs of nodes that interact using P2P interactions.