Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
SerialPatchTree.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
19//%Impl status : Should rewrite
20
21#include "shambase/memory.hpp"
30#include <array>
31#include <tuple>
32#include <vector>
33
// Serialized (flattened) representation of the patch tree, usable on device.
// NOTE(review): the class declaration line (doc line 35) and the lines at
// doc 37/39 (presumably the PtNode alias / serializer typedefs) are missing
// from this extract — confirm against the full header.
34template<class fp_prec_vec>
 36 public:
 38
 40
 41 // TODO use unique pointer instead
 // Number of root nodes of the forest; presumably filled by
 // build_from_patch_tree — TODO confirm.
 42 u32 root_count = 0;
 // SYCL-buffer mirrors of serial_tree / linked_patch_ids; only valid
 // between attach_buf() and detach_buf().
 43 std::unique_ptr<sycl::buffer<PtNode>> serial_tree_buf;
 44 std::unique_ptr<sycl::buffer<u64>> linked_patch_ids_buf;
45
 // Wrap the host-side vectors into SYCL buffers so kernels can access them.
 // Throws if the buffers are already attached.
 // NOTE(review): the throw expressions at doc lines 48/51 (presumably
 // shambase::throw_with_loc<...>(...) calls) are missing from this extract;
 // only their message arguments are visible.
 46 inline void attach_buf() {
 47 if (bool(serial_tree_buf))
 49 "serial_tree_buf is already allocated");
 50 if (bool(linked_patch_ids_buf))
 52 "linked_patch_ids_buf is already allocated");
 53
 // The buffers view the vectors' storage directly (pointer + size ctor),
 // so the vectors must not reallocate while the buffers are attached.
 54 serial_tree_buf
 55 = std::make_unique<sycl::buffer<PtNode>>(serial_tree.data(), serial_tree.size());
 56 linked_patch_ids_buf
 57 = std::make_unique<sycl::buffer<u64>>(linked_patch_ids.data(), linked_patch_ids.size());
 58 }
59
 // Release the SYCL buffers attached by attach_buf() (buffer destruction
 // writes any device updates back to the host vectors).
 // Throws if the buffers were never attached.
 // NOTE(review): the throw expressions at doc lines 62/65 are missing from
 // this extract; only their message arguments are visible.
 60 inline void detach_buf() {
 61 if (!bool(serial_tree_buf))
 63 "serial_tree_buf wasn't allocated");
 64 if (!bool(linked_patch_ids_buf))
 66 "linked_patch_ids_buf wasn't allocated");
 67
 68 serial_tree_buf.reset();
 69 linked_patch_ids_buf.reset();
 70 }
71
 72 private:
 // Depth of the tree; see get_level_count().
 73 u32 level_count = 0;
 74
 // Flattened node storage, per-node linked patch id (u64_max if the node is
 // not a leaf linked to a patch), and the ids of the root nodes.
 75 std::vector<PtNode> serial_tree;
 76 std::vector<u64> linked_patch_ids;
 77 std::vector<u64> roots_ids;
 78
 // Build the serialized tree from a PatchTree.
 // NOTE(review): the parameter list (doc line 80) is missing from this
 // extract — the constructor calls it with (ptree, box_transform).
 79 void build_from_patch_tree(
 81
 82 public:
83 inline void print_status() {
84 if (shamcomm::world_rank() == 0) {
85 for (PtNode n : serial_tree) {
86 logger::raw_ln(
87 n.box_min,
88 n.box_max,
89 "[",
90 n.childs_id[0],
91 n.childs_id[1],
92 n.childs_id[2],
93 n.childs_id[3],
94 n.childs_id[4],
95 n.childs_id[5],
96 n.childs_id[6],
97 n.childs_id[7],
98 "]");
99 }
100 }
101 }
102
 // Construct directly from a PatchTree and a box transform.
 // NOTE(review): the parameter line (doc line 104, presumably
 // `PatchTree &ptree, ... box_transform)`) is missing from this extract.
 103 inline SerialPatchTree(
 105 StackEntry stack_loc{};
 106 build_from_patch_tree(ptree, box_transform);
 107 }
108
109 template<class Acc1, class Acc2>
110 inline void host_for_each_leafs_internal(
111 std::function<bool(u64, PtNode pnode)> interact_cd,
112 std::function<void(u64, PtNode)> found_case,
113 Acc1 &&tree,
114 Acc2 &&lpid) {
115
116 std::stack<u64> id_stack;
117
118 for (u64 root : roots_ids) {
119 id_stack.push(root);
120 }
121
122 while (!id_stack.empty()) {
123 u64 cur_id = id_stack.top();
124 id_stack.pop();
125 PtNode cur_p = tree[cur_id];
126
127 bool interact = interact_cd(cur_id, cur_p);
128
129 if (interact) {
130 u64 linked_id = lpid[cur_id];
131 if (linked_id != u64_max) {
132 found_case(linked_id, cur_p);
133 } else {
134 id_stack.push(cur_p.childs_id[0]);
135 id_stack.push(cur_p.childs_id[1]);
136 id_stack.push(cur_p.childs_id[2]);
137 id_stack.push(cur_p.childs_id[3]);
138 id_stack.push(cur_p.childs_id[4]);
139 id_stack.push(cur_p.childs_id[5]);
140 id_stack.push(cur_p.childs_id[6]);
141 id_stack.push(cur_p.childs_id[7]);
142 }
143 }
144 }
145 }
146
147 inline void host_for_each_leafs(
148 std::function<bool(u64, PtNode pnode)> interact_cd,
149 std::function<void(u64, PtNode)> found_case) {
150 StackEntry stack_loc{false};
151
152 sycl::host_accessor tree{shambase::get_check_ref(serial_tree_buf), sycl::read_only};
153 sycl::host_accessor lpid{shambase::get_check_ref(linked_patch_ids_buf), sycl::read_only};
154
155 host_for_each_leafs_internal(interact_cd, found_case, tree, lpid);
156 }
157
 /// Accessor to the number of levels in the tree.
 163 inline const u32 &get_level_count() { return level_count; }
164
170 inline u32 get_element_count() { return serial_tree.size(); }
171
 // Factory: build a SerialPatchTree from the scheduler's patch tree and its
 // patch coordinate transform.
 // NOTE(review): doc line 173 (presumably
 // `return SerialPatchTree<fp_prec_vec>(`) is missing from this extract.
 172 inline static SerialPatchTree<fp_prec_vec> build(PatchScheduler &sched) {
 174 sched.patch_tree, sched.get_patch_transform<fp_prec_vec>());
 175 }
176
 // Reduce a per-patch field over the whole tree: leaves linked to a patch are
 // seeded from pfield, then `reduc_func::reduce` combines the 8 children into
 // each internal node, swept level_count times (bottom-up fixed-point).
 // NOTE(review): doc line 181 (presumably the
 // `PatchFieldReduction<type> predfield;` declaration) is missing from this
 // extract.
 177 template<class type, class reduc_func>
 178 inline PatchFieldReduction<type> reduce_field(
 179 sycl::queue &queue, PatchScheduler &sched, legacy::PatchField<type> &pfield) {
 180
 182
 // NOTE(review): leftover debug print — consider removing or routing
 // through the logger.
 183 std::cout << "resize to " << get_element_count() << std::endl;
 184 predfield.tree_field.resize(get_element_count());
 185
 // Seed pass: leaves linked to a patch take the patch's field value,
 // every other node is value-initialized.
 186 {
 187 sycl::host_accessor lpid{*linked_patch_ids_buf, sycl::read_only};
 188
 189 // init reduction
 190 std::unordered_map<u64, u64> &idp_to_gid = sched.patch_list.id_patch_to_global_idx;
 191 for (u64 idx = 0; idx < get_element_count(); idx++) {
 192 predfield.tree_field[idx]
 193 = (lpid[idx] != u64_max) ? pfield.global_values[idp_to_gid[lpid[idx]]] : type();
 194
 195 // std::cout << " el " << idx << " " << predfield.tree_field[idx] << std::endl;
 196 }
 197 }
 198
 199 // std::cout << "predfield.attach_buf();" << std::endl;
 200
 201 predfield.attach_buf();
 202
 203 sycl::range<1> range{get_element_count()};
 204
 205 u32 end_loop = get_level_count();
 206
 // One kernel launch per tree level: after k sweeps, values have
 // propagated k levels up from the leaves.
 207 for (u32 level = 0; level < end_loop; level++) {
 208
 209 // {
 210 // auto f = predfield.tree_field_buf->template
 211 // get_access<sycl::access::mode::read>(); std::cout << "["; for (u64 idx = 0; idx <
 212 // get_element_count() ; idx ++) {
 213 // std::cout << f[idx] << ",";
 214 // }
 215 // std::cout << std::endl;
 216 // }
 217
 // NOTE(review): leftover debug print inside the level loop.
 218 std::cout << "queue submit : " << level << " " << end_loop << " " << (level < end_loop)
 219 << std::endl;
 220 queue.submit([&](sycl::handler &cgh) {
 221 auto tree
 222 = this->serial_tree_buf->template get_access<sycl::access::mode::read>(cgh);
 223
 224 auto f
 225 = predfield.tree_field_buf->template get_access<sycl::access::mode::read_write>(
 226 cgh);
 227
 228 cgh.parallel_for<class OctreeReduction>(range, [=](sycl::item<1> item) {
 229 u64 i = (u64) item.get_id(0);
 230
 // NOTE(review): scalar members childs_id0..7 are used here while
 // other methods use the childs_id[8] array — confirm PtNode
 // exposes both, or unify on the array form.
 231 u64 idx_c0 = tree[i].childs_id0;
 232 u64 idx_c1 = tree[i].childs_id1;
 233 u64 idx_c2 = tree[i].childs_id2;
 234 u64 idx_c3 = tree[i].childs_id3;
 235 u64 idx_c4 = tree[i].childs_id4;
 236 u64 idx_c5 = tree[i].childs_id5;
 237 u64 idx_c6 = tree[i].childs_id6;
 238 u64 idx_c7 = tree[i].childs_id7;
 239
 // Internal nodes only (leaf nodes have childs_id0 == u64_max);
 // combine the eight children into this node.
 240 if (idx_c0 != u64_max) {
 241 f[i] = reduc_func::reduce(
 242 f[idx_c0],
 243 f[idx_c1],
 244 f[idx_c2],
 245 f[idx_c3],
 246 f[idx_c4],
 247 f[idx_c5],
 248 f[idx_c6],
 249 f[idx_c7]);
 250 }
 251 });
 252 });
 253 }
 254 // {
 255 // auto f = predfield.tree_field_buf->template get_access<sycl::access::mode::read>();
 256 // std::cout << "[";
 257 // for (u64 idx = 0; idx < get_element_count() ; idx ++) {
 258 // std::cout << f[idx] << ",";
 259 // }
 260 // std::cout << std::endl;
 261 // }
 262
 263 return predfield;
 264 }
265
 // Build a PatchtreeField<T>: seed linked leaves from pfield, then run a
 // bottom-up sweep (one kernel per level) combining the 8 children of each
 // internal node with `reducer`.
 // NOTE(review): doc lines 270 (presumably the `pfield` parameter — the body
 // calls pfield.get) and 272 (presumably the
 // `shamrock::patch::PatchtreeField<T> ptfield;` declaration) are missing
 // from this extract.
 266 template<class T, class Func>
 267 inline shamrock::patch::PatchtreeField<T> make_patch_tree_field(
 268 PatchScheduler &sched,
 269 sycl::queue &queue,
 271 Func &&reducer) {
 273 ptfield.allocate(get_element_count());
 274
 // Seed pass: linked leaves take the patch field value, everything else
 // is value-initialized.
 275 {
 276 sycl::host_accessor lpid{
 277 shambase::get_check_ref(linked_patch_ids_buf), sycl::read_only};
 278 sycl::host_accessor tree_field{
 279 shambase::get_check_ref(ptfield.internal_buf), sycl::write_only, sycl::no_init};
 280
 281 // init reduction
 // NOTE(review): idp_to_gid appears unused in this method (pfield.get
 // is called with the raw linked id) — candidate for removal.
 282 std::unordered_map<u64, u64> &idp_to_gid = sched.patch_list.id_patch_to_global_idx;
 283 for (u64 idx = 0; idx < get_element_count(); idx++) {
 284 tree_field[idx] = (lpid[idx] != u64_max) ? pfield.get(lpid[idx]) : T();
 285 }
 286 }
 287
 288 sycl::range<1> range{get_element_count()};
 289 u32 end_loop = get_level_count();
 290
 // One launch per tree level so values propagate from leaves to roots.
 291 for (u32 level = 0; level < end_loop; level++) {
 292 queue.submit([&](sycl::handler &cgh) {
 293 sycl::accessor tree{shambase::get_check_ref(serial_tree_buf), cgh, sycl::read_only};
 294 sycl::accessor f{
 295 shambase::get_check_ref(ptfield.internal_buf), cgh, sycl::read_write};
 296
 297 cgh.parallel_for(range, [=](sycl::item<1> item) {
 298 u64 i = (u64) item.get_id(0);
 299
 300 std::array<u64, 8> n = tree[i].childs_id;
 301
 // Internal nodes only (leaves have childs_id[0] == u64_max).
 302 if (n[0] != u64_max) {
 303 f[i] = reducer(
 304 f[n[0]], f[n[1]], f[n[2]], f[n[3]], f[n[4]], f[n[5]], f[n[6]], f[n[7]]);
 305 }
 306 });
 307 });
 308 }
 309 return ptfield;
 310 }
311
312 inline void dump_dat() {
313 for (u64 idx = 0; idx < get_element_count(); idx++) {
314 std::cout << idx << " (" << serial_tree[idx].childs_id[0] << ", "
315 << serial_tree[idx].childs_id[1] << ", " << serial_tree[idx].childs_id[2]
316 << ", " << serial_tree[idx].childs_id[3] << ", "
317 << serial_tree[idx].childs_id[4] << ", " << serial_tree[idx].childs_id[5]
318 << ", " << serial_tree[idx].childs_id[6] << ", "
319 << serial_tree[idx].childs_id[7] << ")";
320
321 std::cout << " (" << serial_tree[idx].box_min.x() << ", "
322 << serial_tree[idx].box_min.y() << ", " << serial_tree[idx].box_min.z()
323 << ")";
324
325 std::cout << " (" << serial_tree[idx].box_max.x() << ", "
326 << serial_tree[idx].box_max.y() << ", " << serial_tree[idx].box_max.z()
327 << ")";
328
329 std::cout << " = " << linked_patch_ids[idx];
330
331 std::cout << std::endl;
332 }
333 }
334
 // For each of the `len` positions in position_buffer, compute the id of the
 // patch owning that position (u64_max when unresolved). Defined out of
 // class below; requires attach_buf() to have been called.
 335 sycl::buffer<u64> compute_patch_owner(
 336 sham::DeviceScheduler_ptr dev_sched,
 337 sham::DeviceBuffer<fp_prec_vec> &position_buffer,
 338 u32 len);
 339};
340
// Out-of-class definition of SerialPatchTree<vec>::compute_patch_owner.
// NOTE(review): the signature line (doc line 342, presumably
// `sycl::buffer<u64> SerialPatchTree<vec>::compute_patch_owner(`) is missing
// from this extract.
341template<class vec>
 343 sham::DeviceScheduler_ptr dev_sched, sham::DeviceBuffer<vec> &position_buffer, u32 len) {
 344 sycl::buffer<u64> new_owned_id(len);
 345
 346 using namespace shamrock::patch;
 347
 // Roots are copied into a device-visible buffer for the kernel.
 348 sycl::buffer<u64> roots = shamalgs::vec_to_buf(roots_ids);
 349
 350 auto &q = dev_sched->get_queue();
 351
 352 sham::EventList depends_list;
 353 auto pos = position_buffer.get_read_access(depends_list);
 354
 355 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
 356 sycl::accessor tnode{shambase::get_check_ref(serial_tree_buf), cgh, sycl::read_only};
 357 sycl::accessor linked_node_id{
 358 shambase::get_check_ref(linked_patch_ids_buf), cgh, sycl::read_only};
 359 sycl::accessor roots_id{roots, cgh, sycl::read_only};
 360 sycl::accessor new_id{new_owned_id, cgh, sycl::write_only, sycl::no_init};
 361
 362 u32 root_cnt = roots_id.size();
 // Descent depth bound; auto decays const u32& -> u32.
 363 auto max_lev = get_level_count();
 364
 // NOTE(review): doc line 365 is missing from this extract.
 366
 367 cgh.parallel_for(sycl::range(len), [=](sycl::item<1> item) {
 368 u32 i = (u32) item.get_id(0);
 369
 370 auto xyz = pos[i];
 371
 // NOTE(review): if no root box contains xyz, the descent silently
 // starts at node 0 — confirm this fallback is intended.
 372 u64 current_node = 0;
 373
 374 // find the correct root to start the search
 375 for (u32 iroot = 0; iroot < root_cnt; iroot++) {
 // NOTE(review): root ids are u64 in roots_ids but narrowed to
 // u32 here — confirm ids always fit.
 376 u32 root_id = roots_id[iroot];
 377 PtNode root_node = tnode[root_id];
 378
 379 if (Patch::is_in_patch_converted(xyz, root_node.box_min, root_node.box_max)) {
 380 current_node = root_id;
 381 break;
 382 }
 383 }
 384
 // Stays u64_max if the descent never reaches a linked leaf.
 385 u64 result_node = u64_max;
 386
 // Walk down at most max_lev + 1 levels, stepping into whichever
 // child box contains xyz; stop at a leaf (childs_id[0] == u64_max).
 387 for (u32 step = 0; step < max_lev + 1; step++) {
 388 PtNode cur_node = tnode[current_node];
 389
 390 if (cur_node.childs_id[0] != u64_max) {
 391
 392 if (Patch::is_in_patch_converted(
 393 xyz,
 394 tnode[cur_node.childs_id[0]].box_min,
 395 tnode[cur_node.childs_id[0]].box_max)) {
 396 current_node = cur_node.childs_id[0];
 397 } else if (
 398 Patch::is_in_patch_converted(
 399 xyz,
 400 tnode[cur_node.childs_id[1]].box_min,
 401 tnode[cur_node.childs_id[1]].box_max)) {
 402 current_node = cur_node.childs_id[1];
 403 } else if (
 404 Patch::is_in_patch_converted(
 405 xyz,
 406 tnode[cur_node.childs_id[2]].box_min,
 407 tnode[cur_node.childs_id[2]].box_max)) {
 408 current_node = cur_node.childs_id[2];
 409 } else if (
 410 Patch::is_in_patch_converted(
 411 xyz,
 412 tnode[cur_node.childs_id[3]].box_min,
 413 tnode[cur_node.childs_id[3]].box_max)) {
 414 current_node = cur_node.childs_id[3];
 415 } else if (
 416 Patch::is_in_patch_converted(
 417 xyz,
 418 tnode[cur_node.childs_id[4]].box_min,
 419 tnode[cur_node.childs_id[4]].box_max)) {
 420 current_node = cur_node.childs_id[4];
 421 } else if (
 422 Patch::is_in_patch_converted(
 423 xyz,
 424 tnode[cur_node.childs_id[5]].box_min,
 425 tnode[cur_node.childs_id[5]].box_max)) {
 426 current_node = cur_node.childs_id[5];
 427 } else if (
 428 Patch::is_in_patch_converted(
 429 xyz,
 430 tnode[cur_node.childs_id[6]].box_min,
 431 tnode[cur_node.childs_id[6]].box_max)) {
 432 current_node = cur_node.childs_id[6];
 433 } else if (
 434 Patch::is_in_patch_converted(
 435 xyz,
 436 tnode[cur_node.childs_id[7]].box_min,
 437 tnode[cur_node.childs_id[7]].box_max)) {
 438 current_node = cur_node.childs_id[7];
 439 }
 440
 441 } else {
 442
 // Leaf reached: its linked patch id is the answer.
 443 result_node = linked_node_id[current_node];
 444 break;
 445 }
 446 }
 447
 // Compiled-out debug tracing (kept intentionally disabled).
 448 if constexpr (false) {
 449 PtNode cur_node = tnode[current_node];
 450 if (xyz[0] == 0 && xyz[1] == 0 && xyz[2] == 0) {
 451 logger::raw(
 452 shambase::format(
 453 "{:5} ({}) -> {} [{} {}]\n",
 454 i,
 455 Patch::is_in_patch_converted(xyz, cur_node.box_min, cur_node.box_max),
 456 xyz,
 457 cur_node.box_min,
 458 cur_node.box_max));
 459 }
 460 }
 461
 462 new_id[i] = result_node;
 463 });
 464 });
 465
 466 position_buffer.complete_event_state(e);
 467
 468 return new_owned_id;
 469}
constexpr const char * xyz
Position field (3D coordinates)
MPI scheduler.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
The MPI scheduler.
PatchTree patch_tree
handle the tree structure of the patches
SchedulerPatchList patch_list
handle the list of the patches of the scheduler
std::unordered_map< u64, u64 > id_patch_to_global_idx
id_patch_to_global_idx[patch_id] = index in global patch list
const u32 & get_level_count()
accessor to the number of levels in the tree
u32 get_element_count()
accessor to the number of elements in the tree
Define a field attached to a patch (example: FMM multipoles, hmax in SPH)
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
Patch Tree : Tree structure organisation for an abstract list of patches Nb : this tree is compatible...
Definition PatchTree.hpp:29
sycl::buffer< T > vec_to_buf(const std::vector< T > &buf)
Convert a std::vector to a sycl::buffer
Definition memory.cpp:29
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
constexpr u64 u64_max
u64 max value
This file contains the definition for the stacktrace related functionality.
header file to manage sycl