Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
RadixTree.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
18#include "shamalgs/memory.hpp"
36#include <array>
37#include <memory>
38#include <set>
39#include <stdexcept>
40#include <tuple>
41#include <vector>
42
49template<class Umorton, class Tvec>
50class RadixTree {
51
53
55
56 static constexpr bool pos_is_int
57 = std::is_same<Tvec, u16_3>::value || std::is_same<Tvec, u32_3>::value
58 || std::is_same<Tvec, u64_3>::value || std::is_same<Tvec, i16_3>::value
59 || std::is_same<Tvec, i32_3>::value || std::is_same<Tvec, i64_3>::value;
60
61 static constexpr bool pos_is_float
62 = std::is_same<Tvec, f32_3>::value || std::is_same<Tvec, f64_3>::value;
63
64 RadixTree() = default;
65
66 public:
68 using coord_t = typename shambase::VectorProperties<Tvec>::component_type;
69
70 static constexpr u32 tree_depth = Morton::significant_bits + 1;
71
72 std::tuple<Tvec, Tvec> bounding_box;
73
74 // build by the RadixTreeMortonBuilder
77
78 // Karras alg
80
82
83 inline bool is_tree_built() { return tree_struct.is_built(); }
84
85 inline bool are_range_int_built() { return tree_cell_ranges.are_range_int_built(); }
86
87 inline bool are_range_float_built() { return tree_cell_ranges.are_range_float_built(); }
88
89 void serialize(shamalgs::SerializeHelper &serializer);
90
91 shamalgs::SerializeSize serialize_byte_size();
92
93 static RadixTree deserialize(shamalgs::SerializeHelper &serializer);
94
95 inline friend bool operator==(const RadixTree &t1, const RadixTree &t2) {
96 bool cmp = true;
97 cmp = cmp && sham::equals(std::get<0>(t1.bounding_box), std::get<0>(t2.bounding_box));
98 cmp = cmp && sham::equals(std::get<1>(t1.bounding_box), std::get<1>(t2.bounding_box));
99 cmp = cmp && t1.tree_morton_codes == t2.tree_morton_codes;
100 cmp = cmp && t1.tree_reduced_morton_codes == t2.tree_reduced_morton_codes;
101 cmp = cmp && t1.tree_struct == t2.tree_struct;
102 cmp = cmp && t1.tree_cell_ranges == t2.tree_cell_ranges;
103 return cmp;
104 }
105
106 RadixTreeField<coord_t> compute_int_boxes(
107 sycl::queue &queue, sham::DeviceBuffer<coord_t> &int_rad_buf, coord_t tolerance);
108
109 void compute_cell_ibounding_box(sycl::queue &queue);
110 void convert_bounding_box(sycl::queue &queue);
111
112 inline std::unique_ptr<sycl::buffer<Umorton>> build_new_morton_buf(
113 sycl::buffer<Tvec> &pos_buf, u32 obj_cnt) {
114
115 return tree_morton_codes.build_raw(
116 shamsys::instance::get_compute_queue(),
117 shammath::CoordRange<Tvec>{bounding_box},
118 obj_cnt,
119 pos_buf);
120 }
121
122 RadixTree(
123 sycl::queue &queue,
124 std::tuple<Tvec, Tvec> treebox,
125 const std::unique_ptr<sycl::buffer<Tvec>> &pos_buf,
126 u32 cnt_obj,
127 u32 reduc_level);
128
129 RadixTree(
130 sycl::queue &queue,
131 std::tuple<Tvec, Tvec> treebox,
132 sycl::buffer<Tvec> &pos_buf,
133 u32 cnt_obj,
134 u32 reduc_level);
135
136 RadixTree(
137 sham::DeviceScheduler_ptr dev_sched,
138 std::tuple<Tvec, Tvec> treebox,
140 u32 cnt_obj,
141 u32 reduc_level);
142
143 inline RadixTree(const RadixTree &other)
144 : bounding_box(other.bounding_box), tree_morton_codes{other.tree_morton_codes},
145 tree_reduced_morton_codes(other.tree_reduced_morton_codes), // size = leaf cnt
146 tree_struct{other.tree_struct}, tree_cell_ranges(other.tree_cell_ranges) {}
147
148 [[nodiscard]] inline u64 memsize() const {
149 u64 sum = 0;
150
151 sum += sizeof(bounding_box);
152
153 auto add_ptr = [&](auto &a) {
154 if (a) {
155 sum += a->byte_size();
156 }
157 };
158
159 sum += tree_morton_codes.memsize();
160 sum += tree_reduced_morton_codes.memsize();
161 sum += tree_struct.memsize();
162 sum += tree_cell_ranges.memsize();
163
164 return sum;
165 }
166
167 inline RadixTree duplicate() {
168 const auto &cur = *this;
169 return RadixTree(cur);
170 }
171
172 inline std::unique_ptr<RadixTree> duplicate_to_ptr() {
173 const auto &cur = *this;
174 return std::make_unique<RadixTree>(cur);
175 }
176
177 bool is_same(RadixTree &other) {
178 bool cmp = true;
179
180 cmp = cmp && (sham::equals(std::get<0>(bounding_box), std::get<0>(other.bounding_box)));
181 cmp = cmp && (sham::equals(std::get<1>(bounding_box), std::get<1>(other.bounding_box)));
182 cmp = cmp && (tree_cell_ranges == other.tree_cell_ranges);
183 cmp = cmp
184 && (tree_reduced_morton_codes.tree_leaf_count
185 == other.tree_reduced_morton_codes.tree_leaf_count);
186 cmp = cmp && (tree_struct == other.tree_struct);
187
188 return cmp;
189 }
190
191 template<class T>
193
194 template<class T, class LambdaComputeLeaf, class LambdaCombinator>
195 RadixTreeField<T> compute_field(
196 sycl::queue &queue,
197 u32 nvar,
198 LambdaComputeLeaf &&compute_leaf,
199 LambdaCombinator &&combine) const;
200
201 template<class LambdaForEachCell>
202 std::pair<std::set<u32>, std::set<u32>> get_walk_res_set(LambdaForEachCell &&interact_cd) const;
203
204 template<class LambdaForEachCell>
205 void for_each_leaf(sycl::queue &queue, LambdaForEachCell &&par_for_each_cell) const;
206
207 std::tuple<coord_t, coord_t> get_min_max_cell_side_length();
208
209 struct CuttedTree {
211 std::unique_ptr<sycl::buffer<u32>> new_node_id_to_old;
212
213 std::unique_ptr<sycl::buffer<u32>> pdat_extract_id;
214 };
215
216 CuttedTree cut_tree(sycl::queue &queue, sycl::buffer<u8> &valid_node);
217
218 template<class T>
219 void print_tree_field(sycl::buffer<T> &buf_field);
220
221 static RadixTree make_empty() { return RadixTree(); }
222
224
225 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> particle_index_map;
226 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> reduc_index_map;
227
228 public:
229 LeafIterator(RadixTree &rtree, sycl::handler &cgh)
230 : particle_index_map(rtree.tree_morton_codes.buf_particle_index_map
231 ->template get_access<sycl::access::mode::read>(cgh)),
232 reduc_index_map(rtree.tree_reduced_morton_codes.buf_reduc_index_map
233 ->template get_access<sycl::access::mode::read>(cgh)) {}
234
235 template<class Func>
236 inline void iter_object_in_leaf(u32 leaf_id, Func &&func_it) const noexcept {
237 // loop on particle indexes
238 uint min_ids = reduc_index_map[leaf_id];
239 uint max_ids = reduc_index_map[leaf_id + 1];
240
241 for (unsigned int id_s = min_ids; id_s < max_ids; id_s++) {
242
243 // recover old index before morton sort
244 uint id_b = particle_index_map[id_s];
245
246 // iteration function
247 func_it(id_b);
248 }
249 }
250 };
251
252 inline LeafIterator get_leaf_access(sycl::handler &device_handler) {
253 return LeafIterator(*this, device_handler);
254 }
255};
256
257template<class u_morton, class vec3>
258template<class T, class LambdaComputeLeaf, class LambdaCombinator>
261 sycl::queue &queue,
262 u32 nvar,
263
264 LambdaComputeLeaf &&compute_leaf,
265 LambdaCombinator &&combine) const {
266
268 ret.nvar = nvar;
269
270 shamlog_debug_sycl_ln("RadixTree", "compute_field");
271
272 ret.radix_tree_field_buf = std::make_unique<sycl::buffer<T>>(
273 tree_struct.internal_cell_count + tree_reduced_morton_codes.tree_leaf_count);
274 sycl::range<1> range_leaf_cell{tree_reduced_morton_codes.tree_leaf_count};
275
276 queue.submit([&](sycl::handler &cgh) {
277 u32 offset_leaf = tree_struct.internal_cell_count;
278
279 auto tree_field
280 = sycl::accessor{*ret.radix_tree_field_buf, cgh, sycl::write_only, sycl::no_init};
281
282 auto cell_particle_ids = tree_reduced_morton_codes.buf_reduc_index_map
283 ->template get_access<sycl::access::mode::read>(cgh);
284 auto particle_index_map = tree_morton_codes.buf_particle_index_map
285 ->template get_access<sycl::access::mode::read>(cgh);
286
287 compute_leaf(cgh, [&](auto &&lambda_loop) {
288 cgh.parallel_for(range_leaf_cell, [=](sycl::item<1> item) {
289 u32 gid = (u32) item.get_id(0);
290
291 u32 min_ids = cell_particle_ids[gid];
292 u32 max_ids = cell_particle_ids[gid + 1];
293
294 lambda_loop(
295 [&](auto &&particle_it) {
296 for (unsigned int id_s = min_ids; id_s < max_ids; id_s++) {
297 particle_it(particle_index_map[id_s]);
298 }
299 },
300 tree_field,
301 [&]() {
302 return nvar * (offset_leaf + gid);
303 });
304 });
305 });
306 });
307
308 sycl::range<1> range_tree{tree_struct.internal_cell_count};
309 auto ker_reduc_hmax = [&](sycl::handler &cgh) {
310 u32 offset_leaf = tree_struct.internal_cell_count;
311
312 auto tree_field
313 = ret.radix_tree_field_buf->template get_access<sycl::access::mode::read_write>(cgh);
314
315 auto rchild_id = tree_struct.buf_rchild_id->get_access<sycl::access::mode::read>(cgh);
316 auto lchild_id = tree_struct.buf_lchild_id->get_access<sycl::access::mode::read>(cgh);
317 auto rchild_flag = tree_struct.buf_rchild_flag->get_access<sycl::access::mode::read>(cgh);
318 auto lchild_flag = tree_struct.buf_lchild_flag->get_access<sycl::access::mode::read>(cgh);
319
320 cgh.parallel_for(range_tree, [=](sycl::item<1> item) {
321 u32 gid = (u32) item.get_id(0);
322
323 u32 lid = lchild_id[gid] + offset_leaf * lchild_flag[gid];
324 u32 rid = rchild_id[gid] + offset_leaf * rchild_flag[gid];
325
326 combine(
327 [&](u32 nvar_id) -> T {
328 return tree_field[nvar * lid + nvar_id];
329 },
330 [&](u32 nvar_id) -> T {
331 return tree_field[nvar * rid + nvar_id];
332 },
333 tree_field,
334 [&]() {
335 return nvar * (gid);
336 });
337 });
338 };
339
340 for (u32 i = 0; i < tree_depth; i++) {
341 queue.submit(ker_reduc_hmax);
342 }
343
344 return std::move(ret);
345}
346
347template<class u_morton, class vec3>
348template<class LambdaForEachCell>
349inline std::pair<std::set<u32>, std::set<u32>> RadixTree<u_morton, vec3>::get_walk_res_set(
350 LambdaForEachCell &&interact_cd) const {
351
352 std::set<u32> leaf_list;
353 std::set<u32> rejected_list;
354
355 auto particle_index_map = sycl::host_accessor{*tree_morton_codes.buf_particle_index_map};
356 auto cell_index_map = sycl::host_accessor{*tree_reduced_morton_codes.buf_reduc_index_map};
357 auto rchild_id = sycl::host_accessor{*tree_struct.buf_rchild_id};
358 auto lchild_id = sycl::host_accessor{*tree_struct.buf_lchild_id};
359 auto rchild_flag = sycl::host_accessor{*tree_struct.buf_rchild_flag};
360 auto lchild_flag = sycl::host_accessor{*tree_struct.buf_lchild_flag};
361
362 // sycl::range<1> range_leaf = sycl::range<1>{tree_leaf_count};
363
364 u32 leaf_offset = tree_struct.internal_cell_count;
365
366 u32 stack_cursor = tree_depth - 1;
367 std::array<u32, tree_depth> id_stack;
368 id_stack[stack_cursor] = 0;
369
370 while (stack_cursor < tree_depth) {
371
372 u32 current_node_id = id_stack[stack_cursor];
373 id_stack[stack_cursor] = tree_depth;
374 stack_cursor++;
375
376 if (interact_cd(current_node_id)) {
377
378 // leaf and can interact => force
379 if (current_node_id >= leaf_offset) {
380
381 leaf_list.insert(current_node_id);
382
383 // can interact not leaf => stack
384 } else {
385
386 u32 lid = lchild_id[current_node_id] + leaf_offset * lchild_flag[current_node_id];
387 u32 rid = rchild_id[current_node_id] + leaf_offset * rchild_flag[current_node_id];
388
389 id_stack[stack_cursor - 1] = rid;
390 stack_cursor--;
391
392 id_stack[stack_cursor - 1] = lid;
393 stack_cursor--;
394 }
395 } else {
396 // grav
397
398 rejected_list.insert(current_node_id);
399 }
400 }
401
402 return std::pair<std::set<u32>, std::set<u32>>{std::move(leaf_list), std::move(rejected_list)};
403}
404
405template<class u_morton, class vec3>
406template<class LambdaForEachCell>
408 sycl::queue &queue, LambdaForEachCell &&par_for_each_cell) const {
409
410 queue.submit([&](sycl::handler &cgh) {
411 auto particle_index_map = tree_morton_codes.buf_particle_index_map
412 ->template get_access<sycl::access::mode::read>(cgh);
413 auto cell_index_map = tree_reduced_morton_codes.buf_reduc_index_map
414 ->template get_access<sycl::access::mode::read>(cgh);
415 auto rchild_id
416 = tree_struct.buf_rchild_id->template get_access<sycl::access::mode::read>(cgh);
417 auto lchild_id
418 = tree_struct.buf_lchild_id->template get_access<sycl::access::mode::read>(cgh);
419 auto rchild_flag
420 = tree_struct.buf_rchild_flag->template get_access<sycl::access::mode::read>(cgh);
421 auto lchild_flag
422 = tree_struct.buf_lchild_flag->template get_access<sycl::access::mode::read>(cgh);
423
424 sycl::range<1> range_leaf = sycl::range<1>{tree_reduced_morton_codes.tree_leaf_count};
425
426 u32 leaf_offset = tree_struct.internal_cell_count;
427
428 auto par_for = [&](auto &&for_each_leaf) {
429 cgh.parallel_for(range_leaf, [=](sycl::item<1> item) {
430 u32 id_cell_a = (u32) item.get_id(0) + leaf_offset;
431
432 auto iter_obj_cell = [&](u32 cell_id, auto &&func_it) {
433 uint min_ids = cell_index_map[cell_id - leaf_offset];
434 uint max_ids = cell_index_map[cell_id + 1 - leaf_offset];
435
436 for (unsigned int id_s = min_ids; id_s < max_ids; id_s++) {
437
438 // recover old index before morton sort
439 uint id_b = particle_index_map[id_s];
440
441 // iteration function
442 func_it(id_b);
443 }
444 };
445
446 auto walk_loop = [&](u32 id_cell_a, auto &&for_other_cell) {
447 u32 stack_cursor = tree_depth - 1;
448 std::array<u32, tree_depth> id_stack;
449 id_stack[stack_cursor] = 0;
450
451 while (stack_cursor < tree_depth) {
452
453 u32 current_node_id = id_stack[stack_cursor];
454 id_stack[stack_cursor] = tree_depth;
455 stack_cursor++;
456
457 auto walk_logic = [&](const bool &cur_id_valid,
458 auto &&func_leaf_found,
459 auto &&func_node_rejected) {
460 if (cur_id_valid) {
461
462 // leaf and can interact => force
463 if (current_node_id >= leaf_offset) {
464
465 func_leaf_found();
466
467 // can interact not leaf => stack
468 } else {
469
470 u32 lid = lchild_id[current_node_id]
471 + leaf_offset * lchild_flag[current_node_id];
472 u32 rid = rchild_id[current_node_id]
473 + leaf_offset * rchild_flag[current_node_id];
474
475 id_stack[stack_cursor - 1] = rid;
476 stack_cursor--;
477
478 id_stack[stack_cursor - 1] = lid;
479 stack_cursor--;
480 }
481 } else {
482 // grav
483
484 func_node_rejected();
485 }
486 };
487
488 for_other_cell(current_node_id, walk_logic);
489 }
490 };
491
492 for_each_leaf(id_cell_a, walk_loop, iter_obj_cell);
493 });
494 };
495
496 par_for_each_cell(cgh, par_for);
497 });
498}
499
500template<class u_morton, class vec3>
502 -> std::tuple<coord_t, coord_t> {
503
504 u32 len = tree_reduced_morton_codes.tree_leaf_count;
505
506 sycl::buffer<coord_t> min_side_length{len};
507 sycl::buffer<coord_t> max_side_length{len};
508
510
511 q.submit([&](sycl::handler &cgh) {
512 u32 offset_leaf = tree_struct.internal_cell_count;
513
514 sycl::accessor pos_min_cell{*tree_cell_ranges.buf_pos_min_cell_flt, cgh, sycl::read_only};
515 sycl::accessor pos_max_cell{*tree_cell_ranges.buf_pos_max_cell_flt, cgh, sycl::read_only};
516
517 sycl::accessor s_lengh_min{min_side_length, cgh, sycl::write_only, sycl::no_init};
518 sycl::accessor s_lengh_max{max_side_length, cgh, sycl::write_only, sycl::no_init};
519
520 sycl::range<1> range_tree{tree_reduced_morton_codes.tree_leaf_count};
521
522 cgh.parallel_for(range_tree, [=](sycl::item<1> item) {
523 u32 gid = (u32) item.get_id(0);
524
525 vec3 min = pos_min_cell[gid + offset_leaf];
526 vec3 max = pos_max_cell[gid + offset_leaf];
527
528 vec3 sz = max - min;
529
530 if constexpr (pos_is_float) {
531 s_lengh_min[gid] = sycl::fmin(sycl::fmin(sz.x(), sz.y()), sz.z());
532 s_lengh_max[gid] = sycl::fmax(sycl::fmax(sz.x(), sz.y()), sz.z());
533 }
534
535 if constexpr (pos_is_int) {
536 s_lengh_min[gid] = sycl::min(sycl::min(sz.x(), sz.y()), sz.z());
537 s_lengh_max[gid] = sycl::max(sycl::max(sz.x(), sz.y()), sz.z());
538 }
539 });
540 });
541
542 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
543
544 sham::DeviceBuffer<coord_t> tmp_min_side_length(len, dev_sched);
545 sham::DeviceBuffer<coord_t> tmp_max_side_length(len, dev_sched);
546
547 tmp_min_side_length.copy_from_sycl_buffer(min_side_length);
548 tmp_max_side_length.copy_from_sycl_buffer(max_side_length);
549
550 coord_t min = shamalgs::primitives::min(dev_sched, tmp_min_side_length, 0, len);
551 coord_t max = shamalgs::primitives::max(dev_sched, tmp_max_side_length, 0, len);
552
553 return {min, max};
554}
555
556namespace tree_comm {
557
558 template<class u_morton, class vec3>
560 public:
561 template<class T>
564
565 mpi_sycl_interop::comm_type comm_mode;
566 mpi_sycl_interop::op_type comm_op;
567
568 RTree &rtree;
569
570 std::vector<Request<u_morton>> rq_u_morton;
571 std::vector<Request<u32>> rq_u32;
572 std::vector<Request<u8>> rq_u8;
573 std::vector<Request<vec3>> rq_vec;
574
575 std::vector<Request<typename RTree::ipos_t>> rq_vec3i;
576
577 inline RadixTreeMPIRequest(RTree &rtree, mpi_sycl_interop::op_type comm_op)
578 : rtree(rtree), comm_mode(mpi_sycl_interop::current_mode), comm_op(comm_op) {}
579
580 inline void finalize() {
581 mpi_sycl_interop::waitall(rq_u_morton);
582 mpi_sycl_interop::waitall(rq_u32);
583 mpi_sycl_interop::waitall(rq_u8);
584 mpi_sycl_interop::waitall(rq_vec3i);
585 mpi_sycl_interop::waitall(rq_vec);
586
587 if (comm_op == mpi_sycl_interop::Recv_Probe) {
588 rtree.tree_morton_codes.obj_cnt = rtree.tree_morton_codes.buf_morton->size();
589 rtree.tree_reduced_morton_codes.tree_leaf_count
590 = rtree.tree_reduced_morton_codes.buf_tree_morton->size();
591 rtree.tree_struct.internal_cell_count = rtree.tree_struct.buf_lchild_id->size();
592
593 {
594 sycl::host_accessor bmin{*rtree.tree_cell_ranges.buf_pos_min_cell_flt};
595 sycl::host_accessor bmax{*rtree.tree_cell_ranges.buf_pos_max_cell_flt};
596
597 rtree.bounding_box = {bmin[0], bmax[0]};
598 }
599
600 // One cell mode check
601
602 {
603 sycl::host_accessor indmap{
604 *rtree.tree_reduced_morton_codes.buf_reduc_index_map};
605 rtree.tree_struct.one_cell_mode
606 = (indmap[rtree.tree_reduced_morton_codes.buf_reduc_index_map->size() - 1]
607 == 0);
608 }
609 }
610 }
611 };
612
613 template<class u_morton, class vec3>
614 inline void wait_all(std::vector<RadixTreeMPIRequest<u_morton, vec3>> &rqs) {
615 for (auto &rq : rqs) {
616 rq.finalize();
617 }
618 }
619
620 template<class u_morton, class vec3>
621 inline u64 comm_isend(
623 std::vector<RadixTreeMPIRequest<u_morton, vec3>> &rqs,
624 i32 rank_dest,
625 i32 tag,
626 MPI_Comm comm) {
627
628 u64 ret_len = 0;
629
630 rqs.push_back(RadixTreeMPIRequest<u_morton, vec3>(rtree, mpi_sycl_interop::op_type::Send));
631
632 auto &rq = rqs.back();
633
634 ret_len += mpi_sycl_interop::isend(
635 rq.rtree.tree_morton_codes.buf_morton,
636 rq.rtree.tree_morton_codes.obj_cnt,
637 rq.rq_u_morton,
638 rank_dest,
639 tag,
640 comm);
641 ret_len += mpi_sycl_interop::isend(
642 rq.rtree.tree_morton_codes.buf_particle_index_map,
643 rq.rtree.tree_morton_codes.obj_cnt,
644 rq.rq_u32,
645 rank_dest,
646 tag,
647 comm);
648
649 ret_len += mpi_sycl_interop::isend(
650 rq.rtree.tree_reduced_morton_codes.buf_reduc_index_map,
651 rq.rtree.tree_reduced_morton_codes.tree_leaf_count + 1,
652 rq.rq_u32,
653 rank_dest,
654 tag,
655 comm);
656
657 ret_len += mpi_sycl_interop::isend(
658 rq.rtree.tree_reduced_morton_codes.buf_tree_morton,
659 rq.rtree.tree_reduced_morton_codes.tree_leaf_count,
660 rq.rq_u_morton,
661 rank_dest,
662 tag,
663 comm);
664 ret_len += mpi_sycl_interop::isend(
665 rq.rtree.tree_struct.buf_lchild_id,
666 rq.rtree.tree_struct.internal_cell_count,
667 rq.rq_u32,
668 rank_dest,
669 tag,
670 comm);
671 ret_len += mpi_sycl_interop::isend(
672 rq.rtree.tree_struct.buf_rchild_id,
673 rq.rtree.tree_struct.internal_cell_count,
674 rq.rq_u32,
675 rank_dest,
676 tag,
677 comm);
678 ret_len += mpi_sycl_interop::isend(
679 rq.rtree.tree_struct.buf_lchild_flag,
680 rq.rtree.tree_struct.internal_cell_count,
681 rq.rq_u8,
682 rank_dest,
683 tag,
684 comm);
685 ret_len += mpi_sycl_interop::isend(
686 rq.rtree.tree_struct.buf_rchild_flag,
687 rq.rtree.tree_struct.internal_cell_count,
688 rq.rq_u8,
689 rank_dest,
690 tag,
691 comm);
692 ret_len += mpi_sycl_interop::isend(
693 rq.rtree.tree_struct.buf_endrange,
694 rq.rtree.tree_struct.internal_cell_count,
695 rq.rq_u32,
696 rank_dest,
697 tag,
698 comm);
699
700 ret_len += mpi_sycl_interop::isend(
701 rq.rtree.tree_cell_ranges.buf_pos_min_cell,
702 rq.rtree.tree_struct.internal_cell_count
703 + rq.rtree.tree_reduced_morton_codes.tree_leaf_count,
704 rq.rq_vec3i,
705 rank_dest,
706 tag,
707 comm);
708 ret_len += mpi_sycl_interop::isend(
709 rq.rtree.tree_cell_ranges.buf_pos_max_cell,
710 rq.rtree.tree_struct.internal_cell_count
711 + rq.rtree.tree_reduced_morton_codes.tree_leaf_count,
712 rq.rq_vec3i,
713 rank_dest,
714 tag,
715 comm);
716
717 ret_len += mpi_sycl_interop::isend(
718 rq.rtree.tree_cell_ranges.buf_pos_min_cell_flt,
719 rq.rtree.tree_struct.internal_cell_count
720 + rq.rtree.tree_reduced_morton_codes.tree_leaf_count,
721 rq.rq_vec,
722 rank_dest,
723 tag,
724 comm);
725 ret_len += mpi_sycl_interop::isend(
726 rq.rtree.tree_cell_ranges.buf_pos_max_cell_flt,
727 rq.rtree.tree_struct.internal_cell_count
728 + rq.rtree.tree_reduced_morton_codes.tree_leaf_count,
729 rq.rq_vec,
730 rank_dest,
731 tag,
732 comm);
733
734 return ret_len;
735 }
736
737 template<class u_morton, class vec3>
738 inline u64 comm_irecv_probe(
740 std::vector<RadixTreeMPIRequest<u_morton, vec3>> &rqs,
741 i32 rank_source,
742 i32 tag,
743 MPI_Comm comm) {
744
745 rqs.push_back(
746 RadixTreeMPIRequest<u_morton, vec3>(rtree, mpi_sycl_interop::op_type::Recv_Probe));
747
748 auto &rq = rqs.back();
749
750 u64 ret_len = 0;
751
752 ret_len += mpi_sycl_interop::irecv_probe(
753 rq.rtree.tree_morton_codes.buf_morton, rq.rq_u_morton, rank_source, tag, comm);
754 ret_len += mpi_sycl_interop::irecv_probe(
755 rq.rtree.tree_morton_codes.buf_particle_index_map, rq.rq_u32, rank_source, tag, comm);
756
757 ret_len += mpi_sycl_interop::irecv_probe(
758 rq.rtree.tree_reduced_morton_codes.buf_reduc_index_map,
759 rq.rq_u32,
760 rank_source,
761 tag,
762 comm);
763
764 ret_len += mpi_sycl_interop::irecv_probe(
765 rq.rtree.tree_reduced_morton_codes.buf_tree_morton,
766 rq.rq_u_morton,
767 rank_source,
768 tag,
769 comm);
770 ret_len += mpi_sycl_interop::irecv_probe(
771 rq.rtree.tree_struct.buf_lchild_id, rq.rq_u32, rank_source, tag, comm);
772 ret_len += mpi_sycl_interop::irecv_probe(
773 rq.rtree.tree_struct.buf_rchild_id, rq.rq_u32, rank_source, tag, comm);
774 ret_len += mpi_sycl_interop::irecv_probe(
775 rq.rtree.tree_struct.buf_lchild_flag, rq.rq_u8, rank_source, tag, comm);
776 ret_len += mpi_sycl_interop::irecv_probe(
777 rq.rtree.tree_struct.buf_rchild_flag, rq.rq_u8, rank_source, tag, comm);
778 ret_len += mpi_sycl_interop::irecv_probe(
779 rq.rtree.tree_struct.buf_endrange, rq.rq_u32, rank_source, tag, comm);
780
781 ret_len += mpi_sycl_interop::irecv_probe(
782 rq.rtree.tree_cell_ranges.buf_pos_min_cell, rq.rq_vec3i, rank_source, tag, comm);
783 ret_len += mpi_sycl_interop::irecv_probe(
784 rq.rtree.tree_cell_ranges.buf_pos_max_cell, rq.rq_vec3i, rank_source, tag, comm);
785
786 ret_len += mpi_sycl_interop::irecv_probe(
787 rq.rtree.tree_cell_ranges.buf_pos_min_cell_flt, rq.rq_vec, rank_source, tag, comm);
788 ret_len += mpi_sycl_interop::irecv_probe(
789 rq.rtree.tree_cell_ranges.buf_pos_max_cell_flt, rq.rq_vec, rank_source, tag, comm);
790
791 return ret_len;
792 }
793
794} // namespace tree_comm
795
796// TODO move h iter thing + multipoles to a tree field class
797
798namespace walker {
799
800 namespace interaction_crit {
801 template<class vec3, class flt>
802 inline bool sph_radix_cell_crit(
803 vec3 xyz_a,
804 vec3 part_a_box_min,
805 vec3 part_a_box_max,
806 vec3 cur_cell_box_min,
807 vec3 cur_cell_box_max,
808 flt box_int_sz) {
809
810 vec3 inter_box_b_min = cur_cell_box_min - box_int_sz;
811 vec3 inter_box_b_max = cur_cell_box_max + box_int_sz;
812
813 return BBAA::cella_neigh_b(
814 part_a_box_min, part_a_box_max, cur_cell_box_min, cur_cell_box_max)
815 || BBAA::cella_neigh_b(xyz_a, xyz_a, inter_box_b_min, inter_box_b_max);
816 }
817
818 template<class vec3, class flt>
819 inline bool sph_cell_cell_crit(
820 vec3 cella_min,
821 vec3 cella_max,
822 vec3 cellb_min,
823 vec3 cellb_max,
824 flt rint_a,
825 flt rint_b) {
826
827 vec3 inter_box_a_min = cella_min - rint_a;
828 vec3 inter_box_a_max = cella_max + rint_a;
829
830 vec3 inter_box_b_min = cellb_min - rint_b;
831 vec3 inter_box_b_max = cellb_max + rint_b;
832
833 return BBAA::cella_neigh_b(inter_box_a_min, inter_box_a_max, cellb_min, cellb_max)
834 || BBAA::cella_neigh_b(inter_box_b_min, inter_box_b_max, cella_min, cella_max);
835 }
836 } // namespace interaction_crit
837
838 template<class u_morton, class vec3>
840 public:
841 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> particle_index_map;
842 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> cell_index_map;
843 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> rchild_id;
844 sycl::accessor<u32, 1, sycl::access::mode::read, sycl::target::device> lchild_id;
845 sycl::accessor<u8, 1, sycl::access::mode::read, sycl::target::device> rchild_flag;
846 sycl::accessor<u8, 1, sycl::access::mode::read, sycl::target::device> lchild_flag;
847 sycl::accessor<vec3, 1, sycl::access::mode::read, sycl::target::device> pos_min_cell;
848 sycl::accessor<vec3, 1, sycl::access::mode::read, sycl::target::device> pos_max_cell;
849
850 static constexpr u32 tree_depth = RadixTree<u_morton, vec3>::tree_depth;
851 static constexpr u32 _nindex = 4294967295;
852
853 u32 leaf_offset;
854
855 Radix_tree_accessor(RadixTree<u_morton, vec3> &rtree, sycl::handler &cgh)
856 : particle_index_map(rtree.tree_morton_codes.buf_particle_index_map
857 ->template get_access<sycl::access::mode::read>(cgh)),
858 cell_index_map(rtree.tree_reduced_morton_codes.buf_reduc_index_map
859 ->template get_access<sycl::access::mode::read>(cgh)),
860 rchild_id(
861 rtree.tree_struct.buf_rchild_id->template get_access<sycl::access::mode::read>(
862 cgh)),
863 lchild_id(
864 rtree.tree_struct.buf_lchild_id->template get_access<sycl::access::mode::read>(
865 cgh)),
866 rchild_flag(
867 rtree.tree_struct.buf_rchild_flag->template get_access<sycl::access::mode::read>(
868 cgh)),
869 lchild_flag(
870 rtree.tree_struct.buf_lchild_flag->template get_access<sycl::access::mode::read>(
871 cgh)),
872 pos_min_cell(rtree.tree_cell_ranges.buf_pos_min_cell_flt
873 ->template get_access<sycl::access::mode::read>(cgh)),
874 pos_max_cell(rtree.tree_cell_ranges.buf_pos_max_cell_flt
875 ->template get_access<sycl::access::mode::read>(cgh)),
876 leaf_offset(rtree.tree_struct.internal_cell_count) {}
877 };
878
879 template<class Rta, class Functor_iter>
880 inline void iter_object_in_cell(const Rta &acc, const u32 &cell_id, Functor_iter &&func_it) {
881 // loop on particle indexes
882 uint min_ids = acc.cell_index_map[cell_id - acc.leaf_offset];
883 uint max_ids = acc.cell_index_map[cell_id + 1 - acc.leaf_offset];
884
885 for (unsigned int id_s = min_ids; id_s < max_ids; id_s++) {
886
887 // recover old index before morton sort
888 uint id_b = acc.particle_index_map[id_s];
889
890 // iteration function
891 func_it(id_b);
892 }
893
894 /*
895 std::array<u32, 16> stack_run;
896
897 u32 run_cursor = 16;
898
899 auto is_stack_full = [&]() -> bool{
900 return run_cursor == 0;
901 };
902
903 auto is_stack_not_empty = [&]() -> bool{
904 return run_cursor < 16;
905 };
906
907 auto push_stack = [&](u32 val){
908 run_cursor --;
909 stack_run[run_cursor] = val;
910 };
911
912 auto pop_stack = [&]() -> u32 {
913 u32 v = stack_run[run_cursor];
914 run_cursor ++;
915 return v;
916 };
917
918 auto empty_stack = [&](){
919 while (is_stack_not_empty()) {
920 func_it(pop_stack());
921 }
922 };
923
924 for (unsigned int id_s = min_ids; id_s < max_ids; id_s++) {
925 uint id_b = acc.particle_index_map[id_s];
926
927 if(is_stack_full()){
928 empty_stack();
929 }
930
931 push_stack(id_b);
932
933 }
934
935 empty_stack();
936 */
937 }
938
939 template<class Rta, class Functor_int_cd, class Functor_iter, class Functor_iter_excl>
940 inline void rtree_for_cell(
941 const Rta &acc,
942 Functor_int_cd &&func_int_cd,
943 Functor_iter &&func_it,
944 Functor_iter_excl &&func_excl) {
945 u32 stack_cursor = Rta::tree_depth - 1;
946 std::array<u32, Rta::tree_depth> id_stack;
947 id_stack[stack_cursor] = 0;
948
949 while (stack_cursor < Rta::tree_depth) {
950
951 u32 current_node_id = id_stack[stack_cursor];
952 id_stack[stack_cursor] = Rta::_nindex;
953 stack_cursor++;
954
955 bool cur_id_valid = func_int_cd(current_node_id);
956
957 if (cur_id_valid) {
958
959 // leaf and cell can interact
960 if (current_node_id >= acc.leaf_offset) {
961
962 func_it(current_node_id);
963
964 // can interact not leaf => stack
965 } else {
966
967 u32 lid = acc.lchild_id[current_node_id]
968 + acc.leaf_offset * acc.lchild_flag[current_node_id];
969 u32 rid = acc.rchild_id[current_node_id]
970 + acc.leaf_offset * acc.rchild_flag[current_node_id];
971
972 id_stack[stack_cursor - 1] = rid;
973 stack_cursor--;
974
975 id_stack[stack_cursor - 1] = lid;
976 stack_cursor--;
977 }
978 } else {
979 // grav
980 func_excl(current_node_id);
981 }
982 }
983 }
984
985 template<class Rta, class Functor_int_cd, class Functor_iter, class Functor_iter_excl>
986 inline void rtree_for(
987 const Rta &acc,
988 Functor_int_cd &&func_int_cd,
989 Functor_iter &&func_it,
990 Functor_iter_excl &&func_excl) {
991 u32 stack_cursor = Rta::tree_depth - 1;
992 std::array<u32, Rta::tree_depth> id_stack;
993 id_stack[stack_cursor] = 0;
994
995 while (stack_cursor < Rta::tree_depth) {
996
997 u32 current_node_id = id_stack[stack_cursor];
998 id_stack[stack_cursor] = Rta::_nindex;
999 stack_cursor++;
1000
1001 bool cur_id_valid = func_int_cd(current_node_id);
1002
1003 if (cur_id_valid) {
1004
1005 // leaf and can interact => force
1006 if (current_node_id >= acc.leaf_offset) {
1007
1008 // loop on particle indexes
1009 // uint min_ids = acc.cell_index_map[current_node_id -acc.leaf_offset];
1010 // uint max_ids = acc.cell_index_map[current_node_id + 1 -acc.leaf_offset];
1011 //
1012 // for (unsigned int id_s = min_ids; id_s < max_ids; id_s++) {
1013 //
1014 // //recover old index before morton sort
1015 // uint id_b = acc.particle_index_map[id_s];
1016 //
1017 // //iteration function
1018 // func_it(id_b);
1019 //}
1020
1021 iter_object_in_cell(acc, current_node_id, func_it);
1022
1023 // can interact not leaf => stack
1024 } else {
1025
1026 u32 lid = acc.lchild_id[current_node_id]
1027 + acc.leaf_offset * acc.lchild_flag[current_node_id];
1028 u32 rid = acc.rchild_id[current_node_id]
1029 + acc.leaf_offset * acc.rchild_flag[current_node_id];
1030
1031 id_stack[stack_cursor - 1] = rid;
1032 stack_cursor--;
1033
1034 id_stack[stack_cursor - 1] = lid;
1035 stack_cursor--;
1036 }
1037 } else {
1038 // grav
1039 func_excl(current_node_id);
1040 }
1041 }
1042 }
1043
1044 template<
1045 class Rta,
1046 class Functor_int_cd,
1047 class Functor_iter,
1048 class Functor_iter_excl,
1049 class arr_type>
1050 inline void rtree_for_fill_cache(Rta &acc, arr_type &cell_cache, Functor_int_cd &&func_int_cd) {
1051
1052 constexpr u32 cache_sz = cell_cache.size();
1053 u32 cache_pos = 0;
1054
1055 auto push_in_cache = [&cell_cache, &cache_pos](u32 id) {
1056 cell_cache[cache_pos] = id;
1057 cache_pos++;
1058 };
1059
1060 u32 stack_cursor = Rta::tree_depth - 1;
1061 std::array<u32, Rta::tree_depth> id_stack;
1062 id_stack[stack_cursor] = 0;
1063
1064 auto get_el_cnt_in_stack = [&]() -> u32 {
1065 return Rta::tree_depth - stack_cursor;
1066 };
1067
1068 while ((stack_cursor < Rta::tree_depth) && (cache_pos + get_el_cnt_in_stack < cache_sz)) {
1069
1070 u32 current_node_id = id_stack[stack_cursor];
1071 id_stack[stack_cursor] = Rta::_nindex;
1072 stack_cursor++;
1073
1074 bool cur_id_valid = func_int_cd(current_node_id);
1075
1076 if (cur_id_valid) {
1077
1078 // leaf and can interact => force
1079 if (current_node_id >= acc.leaf_offset) {
1080
1081 // can interact => add to cache
1082 push_in_cache(current_node_id);
1083
1084 // can interact not leaf => stack
1085 } else {
1086
1087 u32 lid = acc.lchild_id[current_node_id]
1088 + acc.leaf_offset * acc.lchild_flag[current_node_id];
1089 u32 rid = acc.rchild_id[current_node_id]
1090 + acc.leaf_offset * acc.rchild_flag[current_node_id];
1091
1092 id_stack[stack_cursor - 1] = rid;
1093 stack_cursor--;
1094
1095 id_stack[stack_cursor - 1] = lid;
1096 stack_cursor--;
1097 }
1098 } else {
1099 // grav
1100 //.....
1101 }
1102 }
1103
1104 while (stack_cursor < Rta::tree_depth) {
1105 u32 current_node_id = id_stack[stack_cursor];
1106 id_stack[stack_cursor] = Rta::_nindex;
1107 stack_cursor++;
1108 push_in_cache(current_node_id);
1109 }
1110
1111 if (cache_pos < cache_sz) {
1112 push_in_cache(u32_max);
1113 }
1114 }
1115
1116 template<
1117 class Rta,
1118 class Functor_int_cd,
1119 class Functor_iter,
1120 class Functor_iter_excl,
1121 class arr_type>
1122 inline void rtree_for(
1123 Rta &acc, arr_type &cell_cache, Functor_int_cd &&func_int_cd, Functor_iter &&func_it) {
1124
1125 constexpr u32 cache_sz = cell_cache.size();
1126
1127 std::array<u32, Rta::tree_depth> id_stack;
1128
1129 auto walk_step = [&](u32 start_id) {
1130 u32 stack_cursor = Rta::tree_depth - 1;
1131 id_stack[stack_cursor] = start_id;
1132
1133 while (stack_cursor < Rta::tree_depth) {
1134
1135 u32 current_node_id = id_stack[stack_cursor];
1136 id_stack[stack_cursor] = Rta::_nindex;
1137 stack_cursor++;
1138
1139 bool cur_id_valid = func_int_cd(current_node_id);
1140
1141 if (cur_id_valid) {
1142
1143 // leaf and can interact => force
1144 if (current_node_id >= acc.leaf_offset) {
1145
1146 // loop on particle indexes
1147 uint min_ids = acc.cell_index_map[current_node_id - acc.leaf_offset];
1148 uint max_ids = acc.cell_index_map[current_node_id + 1 - acc.leaf_offset];
1149
1150 for (unsigned int id_s = min_ids; id_s < max_ids; id_s++) {
1151
1152 // recover old index before morton sort
1153 uint id_b = acc.particle_index_map[id_s];
1154
1155 // iteration function
1156 func_it(id_b);
1157 }
1158
1159 // can interact not leaf => stack
1160 } else {
1161
1162 u32 lid = acc.lchild_id[current_node_id]
1163 + acc.leaf_offset * acc.lchild_flag[current_node_id];
1164 u32 rid = acc.rchild_id[current_node_id]
1165 + acc.leaf_offset * acc.rchild_flag[current_node_id];
1166
1167 id_stack[stack_cursor - 1] = rid;
1168 stack_cursor--;
1169
1170 id_stack[stack_cursor - 1] = lid;
1171 stack_cursor--;
1172 }
1173 } else {
1174 // grav
1175 //...
1176 }
1177 }
1178 };
1179
1180 for (u32 cache_pos = 0; cache_pos < cache_sz && cell_cache[cache_pos] != u32_max;
1181 cache_pos++) {
1182 walk_step(cache_pos);
1183 }
1184 }
1185
1186} // namespace walker
constexpr const char * uint
Specific internal energy u.
Header file describing a Node Instance.
sycl::queue & get_compute_queue(u32 id=0)
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
The radix tree.
Definition RadixTree.hpp:50
A buffer allocated in USM (Unified Shared Memory)
Morton curve implementation.
T min(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Find the minimum element in a device buffer within a specified range.
T max(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Find the maximum element in a device buffer within a specified range.
constexpr u32 u32_max
u32 max value
main include file for memory algorithms