Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
TreeTraversal.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
18#include "shamalgs/numeric.hpp"
19#include "shambackends/sycl.hpp"
21
22namespace shamrock::tree {
23
24 template<class u_morton, class vec>
26
27 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> particle_index_map;
28 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> cell_index_map;
29 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> rchild_id;
30 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> lchild_id;
31 sycl::accessor<u8, 1, sycl::access::mode::read, sycl::target::device> rchild_flag;
32 sycl::accessor<u8, 1, sycl::access::mode::read, sycl::target::device> lchild_flag;
33 sycl::accessor<vec, 1, sycl::access::mode::read, sycl::target::device> pos_min_cell;
34 sycl::accessor<vec, 1, sycl::access::mode::read, sycl::target::device> pos_max_cell;
35
36 static constexpr u32 tree_depth = RadixTree<u_morton, vec>::tree_depth;
37 static constexpr u32 _nindex = 4294967295;
38
39 u32 leaf_offset;
40
41 public:
42 // clang-format off
43 ObjectIterator(const RadixTree< u_morton, vec> & rtree,sycl::handler & cgh):
44 particle_index_map{shambase::get_check_ref(rtree.tree_morton_codes.buf_particle_index_map), cgh,sycl::read_only},
45 cell_index_map{shambase::get_check_ref(rtree.tree_reduced_morton_codes.buf_reduc_index_map), cgh,sycl::read_only},
46 rchild_id {shambase::get_check_ref(rtree.tree_struct.buf_rchild_id) , cgh,sycl::read_only},
47 lchild_id {shambase::get_check_ref(rtree.tree_struct.buf_lchild_id) , cgh,sycl::read_only},
48 rchild_flag {shambase::get_check_ref(rtree.tree_struct.buf_rchild_flag), cgh,sycl::read_only},
49 lchild_flag {shambase::get_check_ref(rtree.tree_struct.buf_lchild_flag), cgh,sycl::read_only},
50 pos_min_cell {shambase::get_check_ref(rtree.tree_cell_ranges.buf_pos_min_cell_flt), cgh,sycl::read_only},
51 pos_max_cell {shambase::get_check_ref(rtree.tree_cell_ranges.buf_pos_max_cell_flt), cgh,sycl::read_only},
52 leaf_offset (rtree.tree_struct.internal_cell_count)
53 {}
54 // clang-format on
55
56 template<class Functor_iter>
57 inline void iter_object_in_cell(const u32 &cell_id, Functor_iter &&func_it) const {
58 // loop on particle indexes
59 uint min_ids = cell_index_map[cell_id - leaf_offset];
60 uint max_ids = cell_index_map[cell_id + 1 - leaf_offset];
61
62 for (unsigned int id_s = min_ids; id_s < max_ids; id_s++) {
63
64 // recover old index before morton sort
65 uint id_b = particle_index_map[id_s];
66
67 // iteration function
68 func_it(id_b);
69 }
70 }
71
72 template<class Functor_int_cd, class Functor_iter, class Functor_iter_excl>
73 inline void rtree_for(
74 Functor_int_cd &&func_int_cd,
75 Functor_iter &&func_it,
76 Functor_iter_excl &&func_excl) const {
77 u32 stack_cursor = tree_depth - 1;
78 std::array<u32, tree_depth> id_stack;
79 id_stack[stack_cursor] = 0;
80
81 while (stack_cursor < tree_depth) {
82
83 u32 current_node_id = id_stack[stack_cursor];
84 id_stack[stack_cursor] = _nindex;
85 stack_cursor++;
86
87 bool cur_id_valid = func_int_cd(
88 current_node_id, pos_min_cell[current_node_id], pos_max_cell[current_node_id]);
89
90 if (cur_id_valid) {
91
92 // leaf and cell can interact
93 if (current_node_id >= leaf_offset) {
94
95 iter_object_in_cell(current_node_id, func_it);
96
97 // can interact not leaf => stack
98 } else {
99
100 u32 lid = lchild_id[current_node_id]
101 + leaf_offset * lchild_flag[current_node_id];
102 u32 rid = rchild_id[current_node_id]
103 + leaf_offset * rchild_flag[current_node_id];
104
105 id_stack[stack_cursor - 1] = rid;
106 stack_cursor--;
107
108 id_stack[stack_cursor - 1] = lid;
109 stack_cursor--;
110 }
111 } else {
112 // grav
113 func_excl(current_node_id);
114 }
115 }
116 }
117
118 template<class Functor_int_cd, class Functor_iter>
119 inline void rtree_for(Functor_int_cd &&func_int_cd, Functor_iter &&func_it) const {
120 rtree_for(
121 std::forward<Functor_int_cd>(func_int_cd),
122 std::forward<Functor_iter>(func_it),
123 [](u32) {});
124 }
125 };
126
127 template<class u_morton, class vec>
129
130 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> rchild_id;
131 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> lchild_id;
132 sycl::accessor<u8, 1, sycl::access::mode::read, sycl::target::device> rchild_flag;
133 sycl::accessor<u8, 1, sycl::access::mode::read, sycl::target::device> lchild_flag;
134
135 public:
136 sycl::accessor<vec, 1, sycl::access::mode::read, sycl::target::device> pos_min_cell;
137 sycl::accessor<vec, 1, sycl::access::mode::read, sycl::target::device> pos_max_cell;
138
139 private:
140 static constexpr u32 tree_depth = RadixTree<u_morton, vec>::tree_depth;
141 static constexpr u32 _nindex = 4294967295;
142
143 u32 leaf_offset;
144
145 public:
146 // clang-format off
147 LeafIterator(RadixTree< u_morton, vec> & rtree,sycl::handler & cgh):
148 rchild_id {shambase::get_check_ref(rtree.tree_struct.buf_rchild_id) , cgh,sycl::read_only},
149 lchild_id {shambase::get_check_ref(rtree.tree_struct.buf_lchild_id) , cgh,sycl::read_only},
150 rchild_flag {shambase::get_check_ref(rtree.tree_struct.buf_rchild_flag), cgh,sycl::read_only},
151 lchild_flag {shambase::get_check_ref(rtree.tree_struct.buf_lchild_flag), cgh,sycl::read_only},
152 pos_min_cell {shambase::get_check_ref(rtree.tree_cell_ranges.buf_pos_min_cell_flt), cgh,sycl::read_only},
153 pos_max_cell {shambase::get_check_ref(rtree.tree_cell_ranges.buf_pos_max_cell_flt), cgh,sycl::read_only},
154 leaf_offset (rtree.tree_struct.internal_cell_count)
155 {}
156 // clang-format on
157
158 template<class Functor_int_cd, class Functor_iter, class Functor_iter_excl>
159 inline void rtree_for(
160 Functor_int_cd &&func_int_cd,
161 Functor_iter &&func_it,
162 Functor_iter_excl &&func_excl) const {
163 u32 stack_cursor = tree_depth - 1;
164 std::array<u32, tree_depth> id_stack;
165 id_stack[stack_cursor] = 0;
166
167 while (stack_cursor < tree_depth) {
168
169 u32 current_node_id = id_stack[stack_cursor];
170 id_stack[stack_cursor] = _nindex;
171 stack_cursor++;
172
173 bool cur_id_valid = func_int_cd(
174 current_node_id, pos_min_cell[current_node_id], pos_max_cell[current_node_id]);
175
176 if (cur_id_valid) {
177
178 // leaf and cell can interact
179 if (current_node_id >= leaf_offset) {
180
181 func_it(current_node_id);
182
183 // can interact not leaf => stack
184 } else {
185
186 u32 lid = lchild_id[current_node_id]
187 + leaf_offset * lchild_flag[current_node_id];
188 u32 rid = rchild_id[current_node_id]
189 + leaf_offset * rchild_flag[current_node_id];
190
191 id_stack[stack_cursor - 1] = rid;
192 stack_cursor--;
193
194 id_stack[stack_cursor - 1] = lid;
195 stack_cursor--;
196 }
197 } else {
198 // grav
199 func_excl(current_node_id);
200 }
201 }
202 }
203
204 template<class Functor_int_cd, class Functor_iter>
205 inline void rtree_for(Functor_int_cd &&func_int_cd, Functor_iter &&func_it) const {
206 rtree_for(
207 std::forward<Functor_int_cd>(func_int_cd),
208 std::forward<Functor_iter>(func_it),
209 [](u32) {});
210 }
211 };
212
213 template<class u_morton, class vec>
215
216 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> rchild_id;
217 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> lchild_id;
218 sycl::accessor<u8, 1, sycl::access::mode::read, sycl::target::device> rchild_flag;
219 sycl::accessor<u8, 1, sycl::access::mode::read, sycl::target::device> lchild_flag;
220
221 sycl::accessor<u_morton, 1, sycl::access::mode::read, sycl::target::device> tree_morton;
222
223 private:
224 static constexpr u32 tree_depth = RadixTree<u_morton, vec>::tree_depth;
225 static constexpr u32 _nindex = 4294967295;
226
227 u32 leaf_offset;
228
229 public:
230 // clang-format off
231 LeafRadixFinder(RadixTree< u_morton, vec> & rtree,sycl::handler & cgh):
232 rchild_id {shambase::get_check_ref(rtree.tree_struct.buf_rchild_id) , cgh,sycl::read_only},
233 lchild_id {shambase::get_check_ref(rtree.tree_struct.buf_lchild_id) , cgh,sycl::read_only},
234 rchild_flag {shambase::get_check_ref(rtree.tree_struct.buf_rchild_flag), cgh,sycl::read_only},
235 lchild_flag {shambase::get_check_ref(rtree.tree_struct.buf_lchild_flag), cgh,sycl::read_only},
236 tree_morton {shambase::get_check_ref(rtree.tree_reduced_morton_codes.buf_tree_morton), cgh,sycl::read_only},
237 leaf_offset (rtree.tree_struct.internal_cell_count)
238 {}
239 // clang-format on
240
247 inline u32 identify_cell(u_morton morton_code) const {
248 u32 current_node_id = 0;
249
250 for (u32 level = 0; level < tree_depth; level++) {
251
252 u32 lid = lchild_id[current_node_id];
253 u32 rid = rchild_id[current_node_id];
254 u32 lflag = lchild_flag[current_node_id];
255 u32 rflag = rchild_flag[current_node_id];
256
257 u_morton m_l = tree_morton[lid];
258 u_morton m_r = tree_morton[rid];
259
260 u32 affinity_l = sham::clz_xor(morton_code, m_l);
261 u32 affinity_r = sham::clz_xor(morton_code, m_r);
262
263 u32 next_id = (affinity_l > affinity_r) ? lid : rid;
264 u32 next_flag = (affinity_l > affinity_r) ? lflag : rflag;
265
266 if (next_flag == 1) {
267 return next_id;
268 }
269
270 current_node_id = next_id;
271 }
272
273 return u32_max;
274 }
275 };
276
277 class LeafCache {
278 public:
279 sham::DeviceBuffer<u32> cnt_neigh;
280 sham::DeviceBuffer<u32> scanned_cnt;
281 u32 sum_neigh_cnt;
282 sham::DeviceBuffer<u32> index_neigh_map;
283
284 struct ptrs {
285 const u32 *cnt_neigh;
286 const u32 *scanned_cnt;
287 const u32 *index_neigh_map;
288 };
289
290 ptrs get_read_access(sham::EventList &depends_list) {
291 return ptrs{
292 cnt_neigh.get_read_access(depends_list),
293 scanned_cnt.get_read_access(depends_list),
294 index_neigh_map.get_read_access(depends_list)};
295 }
296
297 void complete_event_state(sham::EventList &resulting_events) {
298 cnt_neigh.complete_event_state(resulting_events);
299 scanned_cnt.complete_event_state(resulting_events);
300 index_neigh_map.complete_event_state(resulting_events);
301 }
302 };
303
305
306 const u32 *neigh_cnt;
307 const u32 *table_neigh_offset;
308 const u32 *table_neigh;
309
310 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> cell_owner;
311
312 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> particle_index_map;
313 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> cell_index_map;
314
315 u32 leaf_offset;
316
317 public:
318 // clang-format off
319 template<class u_morton, class vec>
320 LeafCacheObjectIterator(RadixTree< u_morton, vec> & rtree,sycl::buffer<u32> & ownerships, LeafCache::ptrs & cache,sycl::handler & cgh):
321 particle_index_map{shambase::get_check_ref(rtree.tree_morton_codes.buf_particle_index_map), cgh,sycl::read_only},
322 cell_index_map{shambase::get_check_ref(rtree.tree_reduced_morton_codes.buf_reduc_index_map), cgh,sycl::read_only},
323 neigh_cnt {cache.cnt_neigh },
324 table_neigh_offset {cache.scanned_cnt },
325 table_neigh {cache.index_neigh_map },
326 cell_owner {ownerships ,cgh,sycl::read_only},
327 leaf_offset (rtree.tree_struct.internal_cell_count)
328 {}
329 // clang-format on
330
331 template<class Functor_iter>
332 inline void iter_object_in_cell(const u32 &cell_id, Functor_iter &&func_it) const {
333 // loop on particle indexes
334 uint min_ids = cell_index_map[cell_id];
335 uint max_ids = cell_index_map[cell_id + 1];
336
337 for (unsigned int id_s = min_ids; id_s < max_ids; id_s++) {
338
339 // recover old index before morton sort
340 uint id_b = particle_index_map[id_s];
341
342 // iteration function
343 func_it(id_b);
344 }
345 }
346
347 template<class Functor_iter>
348 inline void for_each_object(u32 idx, Functor_iter &&func_it) const {
349
350 u32 leaf_cell_owner = cell_owner[idx];
351 u32 cnt = neigh_cnt[leaf_cell_owner];
352 u32 offset_start = table_neigh_offset[leaf_cell_owner];
353 u32 last_idx = offset_start + cnt;
354
355 for (u32 i = offset_start; i < last_idx; i++) {
356 iter_object_in_cell(table_neigh[i] - leaf_offset, func_it);
357 }
358 }
359 };
360
361 struct ObjectCache;
362
364 std::vector<u32> cnt_neigh;
365 std::vector<u32> scanned_cnt;
366 u32 sum_neigh_cnt;
367 std::vector<u32> index_neigh_map;
368
369 inline u64 get_memsize() {
370 return (cnt_neigh.size() + scanned_cnt.size() + index_neigh_map.size() + 1)
371 * sizeof(u32);
372 }
373 };
374
375 struct ObjectCache {
376 sham::DeviceBuffer<u32> cnt_neigh;
377 sham::DeviceBuffer<u32> scanned_cnt;
378 u32 sum_neigh_cnt;
379 sham::DeviceBuffer<u32> index_neigh_map;
380
381 inline u64 get_memsize() {
382 return cnt_neigh.get_mem_usage() + scanned_cnt.get_mem_usage()
383 + index_neigh_map.get_mem_usage() + sizeof(u32);
384 }
385
386 inline HostObjectCache copy_to_host() {
387 return HostObjectCache{
388 cnt_neigh.copy_to_stdvec(),
389 scanned_cnt.copy_to_stdvec(),
390 sum_neigh_cnt,
391 index_neigh_map.copy_to_stdvec(),
392 };
393 }
394
395 inline static ObjectCache build_from_host(HostObjectCache &cache) {
396 sham::DeviceBuffer<u32> cnt_neigh(
397 cache.cnt_neigh.size(), shamsys::instance::get_compute_scheduler_ptr());
398 sham::DeviceBuffer<u32> scanned_cnt(
399 cache.scanned_cnt.size(), shamsys::instance::get_compute_scheduler_ptr());
400 u32 sum_neigh_cnt = cache.sum_neigh_cnt;
401 sham::DeviceBuffer<u32> index_neigh_map(
402 cache.index_neigh_map.size(), shamsys::instance::get_compute_scheduler_ptr());
403
404 cnt_neigh.copy_from_stdvec(cache.cnt_neigh);
405 scanned_cnt.copy_from_stdvec(cache.scanned_cnt);
406 index_neigh_map.copy_from_stdvec(cache.index_neigh_map);
407
408 return ObjectCache{
409 std::move(cnt_neigh),
410 std::move(scanned_cnt),
411 sum_neigh_cnt,
412 std::move(index_neigh_map),
413 };
414 }
415
416 struct ptrs_read {
417 const u32 *cnt_neigh;
418 const u32 *scanned_cnt;
419 const u32 *index_neigh_map;
420 };
421
422 struct ptrs {
423 u32 *cnt_neigh;
424 u32 *scanned_cnt;
425 u32 *index_neigh_map;
426 };
427
428 ptrs_read get_read_access(sham::EventList &depends_list) const {
429 return ptrs_read{
430 cnt_neigh.get_read_access(depends_list),
431 scanned_cnt.get_read_access(depends_list),
432 index_neigh_map.get_read_access(depends_list),
433 };
434 }
435
436 ptrs get_write_access(sham::EventList &depends_list) {
437 return ptrs{
438 cnt_neigh.get_write_access(depends_list),
439 scanned_cnt.get_write_access(depends_list),
440 index_neigh_map.get_write_access(depends_list),
441 };
442 }
443 void complete_event_state(sycl::event &e) const {
444 cnt_neigh.complete_event_state(e);
445 scanned_cnt.complete_event_state(e);
446 index_neigh_map.complete_event_state(e);
447 }
448
449 void complete_event_state(sham::EventList &resulting_events) const {
450 cnt_neigh.complete_event_state(resulting_events);
451 scanned_cnt.complete_event_state(resulting_events);
452 index_neigh_map.complete_event_state(resulting_events);
453 }
454 };
455
456 inline ObjectCache prepare_object_cache(sham::DeviceBuffer<u32> &&counts, u32 obj_cnt) {
457
458 shamlog_debug_sycl_ln("Cache", " reading last value ...");
459 u32 neigh_last_val = shamalgs::memory::extract_element(
460 shamsys::instance::get_compute_scheduler().get_queue(), counts, obj_cnt - 1);
461
462 shamlog_debug_sycl_ln("Cache", " last value =", neigh_last_val);
463
465 shamsys::instance::get_compute_scheduler_ptr(), counts, obj_cnt);
466
467 u32 neigh_sum = neigh_last_val
469 shamsys::instance::get_compute_scheduler().get_queue(),
470 neigh_scanned_vals,
471 obj_cnt - 1);
472
473 shamlog_debug_sycl_ln("Cache", " cache for N=", obj_cnt, "size() =", neigh_sum);
474
475 sham::DeviceBuffer<u32> particle_neigh_map(
476 neigh_sum, shamsys::instance::get_compute_scheduler_ptr());
477
478 tree::ObjectCache pcache{
479 std::move(counts),
480 std::move(neigh_scanned_vals),
481 neigh_sum,
482 std::move(particle_neigh_map)};
483
484 return pcache;
485 }
486
488
489 const u32 *neigh_cnt;
490 const u32 *table_neigh_offset;
491 const u32 *table_neigh;
492
493 public:
494 // clang-format off
496 neigh_cnt {cache.cnt_neigh },
497 table_neigh_offset {cache.scanned_cnt },
498 table_neigh {cache.index_neigh_map }
499 {}
500 // clang-format on
501
502 template<class Functor_iter>
503 inline void for_each_object(u32 idx, Functor_iter &&func_it) const {
504
505 u32 cnt = neigh_cnt[idx];
506 u32 offset_start = table_neigh_offset[idx];
507 u32 last_idx = offset_start + cnt;
508
509 for (u32 i = offset_start; i < last_idx; i++) {
510 func_it(table_neigh[i]);
511 }
512 }
513
514 template<class Functor_iter>
515 inline void for_each_object_with_id(u32 idx, Functor_iter &&func_it) const {
516
517 u32 cnt = neigh_cnt[idx];
518 u32 offset_start = table_neigh_offset[idx];
519 u32 last_idx = offset_start + cnt;
520
521 for (u32 i = offset_start; i < last_idx; i++) {
522 func_it(table_neigh[i], i);
523 }
524 }
525 };
526
527} // namespace shamrock::tree
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
The radix tree.
Definition RadixTree.hpp:50
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
void copy_from_stdvec(const std::vector< T > &vec)
Copy the content of a std::vector into the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
std::vector< T > copy_to_stdvec() const
Copy the content of the buffer to a std::vector.
size_t get_mem_usage() const
Gets the amount of memory used by the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
u32 identify_cell(u_morton morton_code) const
identify leaf owning the asked code
constexpr T clz_xor(T a, T b) noexcept
give the length of the common prefix
Definition math.hpp:783
T extract_element(sycl::queue &q, sycl::buffer< T > &buf, u32 idx)
extract a value of a buffer
Definition memory.cpp:24
sycl::buffer< T > scan_exclusive(sycl::queue &q, sycl::buffer< T > &buf1, u32 len)
Computes the exclusive sum of elements in a SYCL buffer.
Definition numeric.cpp:35
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the pointer is null.
Definition memory.hpp:110
constexpr u32 u32_max
u32 max value