// fp_prec_vec: vector type forwarded to sched.get_patch_transform<fp_prec_vec>()
// when the serialized tree is built from a scheduler.
34template<
class fp_prec_vec>
// SYCL buffer over serial_tree; allocated by attach_buf(), released (reset) by detach_buf().
43 std::unique_ptr<sycl::buffer<PtNode>> serial_tree_buf;
// SYCL buffer over linked_patch_ids; allocated by attach_buf(), released (reset) by detach_buf().
44 std::unique_ptr<sycl::buffer<u64>> linked_patch_ids_buf;
// Wraps the host vectors serial_tree and linked_patch_ids in SYCL buffers, so that
// kernels and host accessors can read them.
// Fails with "... is already allocated" if either buffer already exists. Each
// attach_buf() must therefore be paired with a detach_buf() before attaching again.
// NOTE(review): the buffers are constructed from the vectors' data() pointers. Per
// SYCL buffer semantics they use that host storage, so the vectors presumably must
// not be resized or reallocated while attached — confirm.
46 inline void attach_buf() {
47 if (
bool(serial_tree_buf))
49 "serial_tree_buf is already allocated");
50 if (
bool(linked_patch_ids_buf))
52 "linked_patch_ids_buf is already allocated");
55 = std::make_unique<sycl::buffer<PtNode>>(serial_tree.data(), serial_tree.size());
57 = std::make_unique<sycl::buffer<u64>>(linked_patch_ids.data(), linked_patch_ids.size());
// Releases the SYCL buffers over serial_tree and linked_patch_ids by resetting both
// unique_ptrs.
// Fails with "... wasn't allocated" if either buffer is not currently allocated.
// NOTE(review): per SYCL semantics, destroying a buffer built from a host pointer
// waits for pending work and writes the data back to that host memory. Presumably
// the host vectors therefore hold up-to-date values afterwards — confirm.
60 inline void detach_buf() {
61 if (!
bool(serial_tree_buf))
63 "serial_tree_buf wasn't allocated");
64 if (!
bool(linked_patch_ids_buf))
66 "linked_patch_ids_buf wasn't allocated");
68 serial_tree_buf.reset();
69 linked_patch_ids_buf.reset();
// Flattened tree: nodes are addressed by index (tree[node_id]).
// Each PtNode holds 8 child indices (childs_id) and its box (box_min / box_max).
75 std::vector<PtNode> serial_tree;
// Patch id linked to each node, indexed like serial_tree.
// u64_max marks a node with no linked patch.
76 std::vector<u64> linked_patch_ids;
// Indices into serial_tree of the root nodes; tree traversals start from these.
77 std::vector<u64> roots_ids;
// Presumably fills the vectors above from a PatchTree (definition not visible here).
79 void build_from_patch_tree(
// Iterates over every node of serial_tree, presumably printing it (loop body not
// visible here).
// NOTE(review): `PtNode n` copies each node; `const PtNode &n` would avoid the copy.
83 inline void print_status() {
85 for (
PtNode n : serial_tree) {
// Constructor body (signature not visible here): builds the serialized tree from
// ptree, using box_transform.
106 build_from_patch_tree(ptree, box_transform);
// Host-side depth-first traversal of the serialized tree. It uses an explicit stack
// (std::stack) instead of recursion and starts from every id in roots_ids.
// For each visited node:
//  - interact_cd(node_id, node) is evaluated;
//  - the node's linked patch id (lpid[node_id]) may be passed to
//    found_case(linked_patch_id, node);
//  - its 8 child ids may be pushed onto the stack.
// NOTE(review): the conditions that choose between these paths are not visible here.
// Presumably found_case is used for leaves, and children are pushed only for
// interacting internal nodes — confirm.
// Acc1 / Acc2: accessor-like types indexed by node id. They are used for the node
// array (tree) and the linked patch id array (lpid) respectively.
109 template<
class Acc1,
class Acc2>
110 inline void host_for_each_leafs_internal(
111 std::function<
bool(
u64,
PtNode pnode)> interact_cd,
112 std::function<
void(
u64,
PtNode)> found_case,
116 std::stack<u64> id_stack;
118 for (
u64 root : roots_ids) {
122 while (!id_stack.empty()) {
// visit the node on top of the stack
123 u64 cur_id = id_stack.top();
125 PtNode cur_p = tree[cur_id];
127 bool interact = interact_cd(cur_id, cur_p);
130 u64 linked_id = lpid[cur_id];
132 found_case(linked_id, cur_p);
// schedule all 8 children of the current node
134 id_stack.push(cur_p.childs_id[0]);
135 id_stack.push(cur_p.childs_id[1]);
136 id_stack.push(cur_p.childs_id[2]);
137 id_stack.push(cur_p.childs_id[3]);
138 id_stack.push(cur_p.childs_id[4]);
139 id_stack.push(cur_p.childs_id[5]);
140 id_stack.push(cur_p.childs_id[6]);
141 id_stack.push(cur_p.childs_id[7]);
// Public entry point of the host traversal. It obtains the node array (tree) and
// the linked patch id array (lpid); the lines creating them are not visible here.
// It then forwards both callbacks to host_for_each_leafs_internal.
// interact_cd(node_id, node): traversal predicate evaluated on each visited node.
// found_case(linked_patch_id, node): callback receiving a linked patch id and its node.
147 inline void host_for_each_leafs(
148 std::function<
bool(
u64,
PtNode pnode)> interact_cd,
149 std::function<
void(
u64,
PtNode)> found_case) {
155 host_for_each_leafs_internal(interact_cd, found_case, tree, lpid);
// Constructor fragment (signature not visible here). It builds the serialized tree
// from the scheduler's patch_tree and the scheduler's patch transform for fp_prec_vec.
174 sched.
patch_tree, sched.get_patch_transform<fp_prec_vec>());
// Fills a tree field from a patch field, then reduces it over the tree.
// type: value stored per node. reduc_func: provides a static reduce(...) that
// computes a node's value (its arguments are not visible here).
// Visible steps:
//  1. For each node idx (loop header not visible):
//     - predfield.tree_field[idx] = pfield.global_values[idp_to_gid[lpid[idx]]]
//       when the node is linked to a patch (lpid[idx] != u64_max);
//     - type() otherwise.
//  2. predfield's buffer is attached. Then, for level = 0 .. end_loop-1, a kernel
//     reads node i's child ids (childs_id0..7) and writes f[i] = reduc_func::reduce(...).
// NOTE(review): the kernel is always named `class OctreeReduction`, whatever the
// template arguments. SYCL requires distinct kernel names, so instantiating this
// with several (type, reduc_func) pairs may fail. The reduction kernel at original
// line 297 uses an unnamed lambda instead. Verify.
// NOTE(review): this kernel reads tree[i].childs_id0 .. childs_id7, while the rest of
// the class uses the array tree[i].childs_id[k]. Confirm PtNode still has these fields.
// NOTE(review): the std::cout inside the level loop looks like leftover debug output.
177 template<
class type,
class reduc_func>
187 sycl::host_accessor lpid{*linked_patch_ids_buf, sycl::read_only};
192 predfield.tree_field[idx]
193 = (lpid[idx] !=
u64_max) ? pfield.global_values[idp_to_gid[lpid[idx]]] : type();
201 predfield.attach_buf();
207 for (
u32 level = 0; level < end_loop; level++) {
218 std::cout <<
"queue submit : " << level <<
" " << end_loop <<
" " << (level < end_loop)
220 queue.submit([&](sycl::handler &cgh) {
222 = this->serial_tree_buf->template get_access<sycl::access::mode::read>(cgh);
225 = predfield.tree_field_buf->template get_access<sycl::access::mode::read_write>(
228 cgh.parallel_for<
class OctreeReduction>(range, [=](sycl::item<1> item) {
229 u64 i = (
u64) item.get_id(0);
231 u64 idx_c0 = tree[i].childs_id0;
232 u64 idx_c1 = tree[i].childs_id1;
233 u64 idx_c2 = tree[i].childs_id2;
234 u64 idx_c3 = tree[i].childs_id3;
235 u64 idx_c4 = tree[i].childs_id4;
236 u64 idx_c5 = tree[i].childs_id5;
237 u64 idx_c6 = tree[i].childs_id6;
238 u64 idx_c7 = tree[i].childs_id7;
241 f[i] = reduc_func::reduce(
// Tree reduction variant. T: value type per node. Func: presumably the functor that
// combines the 8 child values.
// Visible steps:
//  1. Host pass:
//     - tree_field[idx] = pfield.get(lpid[idx]) for nodes linked to a patch
//       (lpid[idx] != u64_max);
//     - T() otherwise.
//  2. For level = 0 .. end_loop-1, a kernel reads node i's child ids
//     (n = tree[i].childs_id) and passes f[n[0]] .. f[n[7]] to a combining call
//     that is not visible here.
266 template<
class T,
class Func>
276 sycl::host_accessor lpid{
278 sycl::host_accessor tree_field{
284 tree_field[idx] = (lpid[idx] !=
u64_max) ? pfield.get(lpid[idx]) : T();
291 for (
u32 level = 0; level < end_loop; level++) {
292 queue.submit([&](sycl::handler &cgh) {
297 cgh.parallel_for(range, [=](sycl::item<1> item) {
298 u64 i = (
u64) item.get_id(0);
300 std::array<u64, 8> n = tree[i].childs_id;
304 f[n[0]], f[n[1]], f[n[2]], f[n[3]], f[n[4]], f[n[5]], f[n[6]], f[n[7]]);
// Debug dump of the serialized tree to std::cout, presumably one line per node idx
// (loop header not visible here). Each line has roughly this shape:
//   idx (8 child ids) (box_min x, y, z) (box_max x, y, z) = linked patch id
// NOTE(review): std::endl flushes on every node; '\n' would be cheaper for large trees.
312 inline void dump_dat() {
314 std::cout << idx <<
" (" << serial_tree[idx].childs_id[0] <<
", "
315 << serial_tree[idx].childs_id[1] <<
", " << serial_tree[idx].childs_id[2]
316 <<
", " << serial_tree[idx].childs_id[3] <<
", "
317 << serial_tree[idx].childs_id[4] <<
", " << serial_tree[idx].childs_id[5]
318 <<
", " << serial_tree[idx].childs_id[6] <<
", "
319 << serial_tree[idx].childs_id[7] <<
")";
321 std::cout <<
" (" << serial_tree[idx].box_min.x() <<
", "
322 << serial_tree[idx].box_min.y() <<
", " << serial_tree[idx].box_min.z()
325 std::cout <<
" (" << serial_tree[idx].box_max.x() <<
", "
326 << serial_tree[idx].box_max.y() <<
", " << serial_tree[idx].box_max.z()
329 std::cout <<
" = " << linked_patch_ids[idx];
331 std::cout << std::endl;
// For each of the len elements, locates a tree node whose box contains the element's
// position (xyz). It writes that node's linked patch id into the returned buffer
// (new_owned_id, one u64 per element).
// Search, per work-item:
//  - select a root (from the roots buffer) whose box contains xyz;
//  - then descend for at most get_level_count() + 1 steps. While the current node
//    has children (childs_id[0] != u64_max), move to a child whose box contains xyz,
//    testing childs_id[0] .. childs_id[7] in order.
// Containment is tested with Patch::is_in_patch_converted(xyz, box_min, box_max).
// The kernel is submitted after depends_list and its event is kept as `e`.
// NOTE(review): xyz is defined in lines not visible here; presumably it is the
// position of element i.
// NOTE(review): `u32 root_id = roots_id[iroot]` stores a node index in a u32, while
// current_node is u64. Confirm the roots buffer element type and that node counts
// stay below 2^32.
// NOTE(review): current_node starts at 0. Unless lines not visible here handle it,
// a position outside every root box falls back to node 0.
// NOTE(review): the `if constexpr (false)` block is disabled debug printing.
335 sycl::buffer<u64> compute_patch_owner(
336 sham::DeviceScheduler_ptr dev_sched,
344 sycl::buffer<u64> new_owned_id(len);
346 using namespace shamrock::patch;
350 auto &q = dev_sched->get_queue();
355 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
357 sycl::accessor linked_node_id{
359 sycl::accessor roots_id{roots, cgh, sycl::read_only};
360 sycl::accessor new_id{new_owned_id, cgh, sycl::write_only, sycl::no_init};
362 u32 root_cnt = roots_id.size();
363 auto max_lev = get_level_count();
367 cgh.parallel_for(sycl::range(len), [=](sycl::item<1> item) {
368 u32 i = (
u32) item.get_id(0);
372 u64 current_node = 0;
// pick a root node whose box contains the position
375 for (
u32 iroot = 0; iroot < root_cnt; iroot++) {
376 u32 root_id = roots_id[iroot];
377 PtNode root_node = tnode[root_id];
379 if (Patch::is_in_patch_converted(xyz, root_node.box_min, root_node.box_max)) {
380 current_node = root_id;
// descend through the levels towards the node containing the position
387 for (
u32 step = 0; step < max_lev + 1; step++) {
388 PtNode cur_node = tnode[current_node];
390 if (cur_node.childs_id[0] !=
u64_max) {
392 if (Patch::is_in_patch_converted(
394 tnode[cur_node.childs_id[0]].box_min,
395 tnode[cur_node.childs_id[0]].box_max)) {
396 current_node = cur_node.childs_id[0];
398 Patch::is_in_patch_converted(
400 tnode[cur_node.childs_id[1]].box_min,
401 tnode[cur_node.childs_id[1]].box_max)) {
402 current_node = cur_node.childs_id[1];
404 Patch::is_in_patch_converted(
406 tnode[cur_node.childs_id[2]].box_min,
407 tnode[cur_node.childs_id[2]].box_max)) {
408 current_node = cur_node.childs_id[2];
410 Patch::is_in_patch_converted(
412 tnode[cur_node.childs_id[3]].box_min,
413 tnode[cur_node.childs_id[3]].box_max)) {
414 current_node = cur_node.childs_id[3];
416 Patch::is_in_patch_converted(
418 tnode[cur_node.childs_id[4]].box_min,
419 tnode[cur_node.childs_id[4]].box_max)) {
420 current_node = cur_node.childs_id[4];
422 Patch::is_in_patch_converted(
424 tnode[cur_node.childs_id[5]].box_min,
425 tnode[cur_node.childs_id[5]].box_max)) {
426 current_node = cur_node.childs_id[5];
428 Patch::is_in_patch_converted(
430 tnode[cur_node.childs_id[6]].box_min,
431 tnode[cur_node.childs_id[6]].box_max)) {
432 current_node = cur_node.childs_id[6];
434 Patch::is_in_patch_converted(
436 tnode[cur_node.childs_id[7]].box_min,
437 tnode[cur_node.childs_id[7]].box_max)) {
438 current_node = cur_node.childs_id[7];
// map the node reached by the descent to its linked patch id
443 result_node = linked_node_id[current_node];
// disabled debug output (compiled out by `if constexpr (false)`)
448 if constexpr (
false) {
449 PtNode cur_node = tnode[current_node];
450 if (xyz[0] == 0 && xyz[1] == 0 && xyz[2] == 0) {
453 "{:5} ({}) -> {} [{} {}]\n",
455 Patch::is_in_patch_converted(xyz, cur_node.box_min, cur_node.box_max),
462 new_id[i] = result_node;
constexpr const char * xyz
Position field (3D coordinates)
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
PatchTree patch_tree
handle the tree structure of the patches
SchedulerPatchList patch_list
handle the list of the patches of the scheduler
std::unordered_map< u64, u64 > id_patch_to_global_idx
id_patch_to_global_idx[patch_id] = index in global patch list
const u32 & get_level_count()
Accessor to the number of levels in the tree
u32 get_element_count()
Accessor to the number of elements in the tree
Define a field attached to a patch (example: FMM multipoles, hmax in SPH)
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
Class to manage a list of SYCL events.
Patch Tree : Tree structure organisation for an abstract list of patches Nb : this tree is compatible...
sycl::buffer< T > vec_to_buf(const std::vector< T > &buf)
Convert a std::vector to a sycl::buffer
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
constexpr u64 u64_max
u64 max value
This file contains the definition for the stacktrace related functionality.
header file to manage sycl