Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
SerialPatchTree.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
19
20//%Impl status : Should rewrite
21
22#include "shambase/memory.hpp"
30#include <array>
31#include <tuple>
32#include <vector>
33
34template<class fp_prec_vec>
35class SerialPatchTree {
36 public:
38
39 using PatchTree = shamrock::scheduler::PatchTree;
40
41 // TODO use unique pointer instead
42 u32 root_count = 0;
43 std::unique_ptr<sycl::buffer<PtNode>> serial_tree_buf;
44 std::unique_ptr<sycl::buffer<u64>> linked_patch_ids_buf;
45
46 inline void attach_buf() {
47 if (bool(serial_tree_buf))
49 "serial_tree_buf is already allocated");
50 if (bool(linked_patch_ids_buf))
52 "linked_patch_ids_buf is already allocated");
53
54 serial_tree_buf
55 = std::make_unique<sycl::buffer<PtNode>>(serial_tree.data(), serial_tree.size());
56 linked_patch_ids_buf
57 = std::make_unique<sycl::buffer<u64>>(linked_patch_ids.data(), linked_patch_ids.size());
58 }
59
60 inline void detach_buf() {
61 if (!bool(serial_tree_buf))
63 "serial_tree_buf wasn't allocated");
64 if (!bool(linked_patch_ids_buf))
66 "linked_patch_ids_buf wasn't allocated");
67
68 serial_tree_buf.reset();
69 linked_patch_ids_buf.reset();
70 }
71
72 private:
73 u32 level_count = 0;
74
75 std::vector<PtNode> serial_tree;
76 std::vector<u64> linked_patch_ids;
77 std::vector<u64> roots_ids;
78
79 void build_from_patch_tree(
80 PatchTree &ptree, const shamrock::patch::PatchCoordTransform<fp_prec_vec> box_transform);
81
82 public:
83 inline void print_status() {
84 if (shamcomm::world_rank() == 0) {
85 for (PtNode n : serial_tree) {
87 n.box_min,
88 n.box_max,
89 "[",
90 n.childs_id[0],
91 n.childs_id[1],
92 n.childs_id[2],
93 n.childs_id[3],
94 n.childs_id[4],
95 n.childs_id[5],
96 n.childs_id[6],
97 n.childs_id[7],
98 "]");
99 }
100 }
101 }
102
103 inline SerialPatchTree(
104 PatchTree &ptree, const shamrock::patch::PatchCoordTransform<fp_prec_vec> box_transform) {
105 StackEntry stack_loc{};
106 build_from_patch_tree(ptree, box_transform);
107 }
108
109 template<class Acc1, class Acc2>
110 inline void host_for_each_leafs_internal(
111 std::function<bool(u64, PtNode pnode)> interact_cd,
112 std::function<void(u64, PtNode)> found_case,
113 Acc1 &&tree,
114 Acc2 &&lpid) {
115
116 std::stack<u64> id_stack;
117
118 for (u64 root : roots_ids) {
119 id_stack.push(root);
120 }
121
122 while (!id_stack.empty()) {
123 u64 cur_id = id_stack.top();
124 id_stack.pop();
125 PtNode cur_p = tree[cur_id];
126
127 bool interact = interact_cd(cur_id, cur_p);
128
129 if (interact) {
130 u64 linked_id = lpid[cur_id];
131 if (linked_id != u64_max) {
132 found_case(linked_id, cur_p);
133 } else {
134 id_stack.push(cur_p.childs_id[0]);
135 id_stack.push(cur_p.childs_id[1]);
136 id_stack.push(cur_p.childs_id[2]);
137 id_stack.push(cur_p.childs_id[3]);
138 id_stack.push(cur_p.childs_id[4]);
139 id_stack.push(cur_p.childs_id[5]);
140 id_stack.push(cur_p.childs_id[6]);
141 id_stack.push(cur_p.childs_id[7]);
142 }
143 }
144 }
145 }
146
147 inline void host_for_each_leafs(
148 std::function<bool(u64, PtNode pnode)> interact_cd,
149 std::function<void(u64, PtNode)> found_case) {
150 StackEntry stack_loc{false};
151
152 sycl::host_accessor tree{shambase::get_check_ref(serial_tree_buf), sycl::read_only};
153 sycl::host_accessor lpid{shambase::get_check_ref(linked_patch_ids_buf), sycl::read_only};
154
155 host_for_each_leafs_internal(interact_cd, found_case, tree, lpid);
156 }
157
163 inline const u32 &get_level_count() { return level_count; }
164
170 inline u32 get_element_count() { return serial_tree.size(); }
171
172 inline static SerialPatchTree<fp_prec_vec> build(PatchScheduler &sched) {
174 sched.patch_tree, sched.get_patch_transform<fp_prec_vec>());
175 }
176
177 template<class T, class Func>
178 inline shamrock::patch::PatchtreeField<T> make_patch_tree_field(
179 PatchScheduler &sched,
180 sycl::queue &queue,
181 shamrock::patch::PatchField<T> pfield,
182 Func &&reducer) {
183 shamrock::patch::PatchtreeField<T> ptfield;
184 ptfield.allocate(get_element_count());
185
186 {
187 sycl::host_accessor lpid{
188 shambase::get_check_ref(linked_patch_ids_buf), sycl::read_only};
189 sycl::host_accessor tree_field{
190 shambase::get_check_ref(ptfield.internal_buf), sycl::write_only, sycl::no_init};
191
192 // init reduction
193 std::unordered_map<u64, u64> &idp_to_gid = sched.patch_list.id_patch_to_global_idx;
194 for (u64 idx = 0; idx < get_element_count(); idx++) {
195 tree_field[idx] = (lpid[idx] != u64_max) ? pfield.get(lpid[idx]) : T();
196 }
197 }
198
199 sycl::range<1> range{get_element_count()};
200 u32 end_loop = get_level_count();
201
202 for (u32 level = 0; level < end_loop; level++) {
203 queue.submit([&](sycl::handler &cgh) {
204 sycl::accessor tree{shambase::get_check_ref(serial_tree_buf), cgh, sycl::read_only};
205 sycl::accessor f{
206 shambase::get_check_ref(ptfield.internal_buf), cgh, sycl::read_write};
207
208 cgh.parallel_for(range, [=](sycl::item<1> item) {
209 u64 i = (u64) item.get_id(0);
210
211 std::array<u64, 8> n = tree[i].childs_id;
212
213 if (n[0] != u64_max) {
214 f[i] = reducer(
215 f[n[0]], f[n[1]], f[n[2]], f[n[3]], f[n[4]], f[n[5]], f[n[6]], f[n[7]]);
216 }
217 });
218 });
219 }
220 return ptfield;
221 }
222
223 inline void dump_dat() {
224 for (u64 idx = 0; idx < get_element_count(); idx++) {
225 std::cout << idx << " (" << serial_tree[idx].childs_id[0] << ", "
226 << serial_tree[idx].childs_id[1] << ", " << serial_tree[idx].childs_id[2]
227 << ", " << serial_tree[idx].childs_id[3] << ", "
228 << serial_tree[idx].childs_id[4] << ", " << serial_tree[idx].childs_id[5]
229 << ", " << serial_tree[idx].childs_id[6] << ", "
230 << serial_tree[idx].childs_id[7] << ")";
231
232 std::cout << " (" << serial_tree[idx].box_min.x() << ", "
233 << serial_tree[idx].box_min.y() << ", " << serial_tree[idx].box_min.z()
234 << ")";
235
236 std::cout << " (" << serial_tree[idx].box_max.x() << ", "
237 << serial_tree[idx].box_max.y() << ", " << serial_tree[idx].box_max.z()
238 << ")";
239
240 std::cout << " = " << linked_patch_ids[idx];
241
242 std::cout << std::endl;
243 }
244 }
245
246 sycl::buffer<u64> compute_patch_owner(
247 sham::DeviceScheduler_ptr dev_sched,
248 sham::DeviceBuffer<fp_prec_vec> &position_buffer,
249 u32 len);
250};
251
252template<class vec>
253sycl::buffer<u64> SerialPatchTree<vec>::compute_patch_owner(
254 sham::DeviceScheduler_ptr dev_sched, sham::DeviceBuffer<vec> &position_buffer, u32 len) {
255 sycl::buffer<u64> new_owned_id(len);
256
257 using namespace shamrock::patch;
258
259 sycl::buffer<u64> roots = shamalgs::vec_to_buf(roots_ids);
260
261 auto &q = dev_sched->get_queue();
262
263 sham::EventList depends_list;
264 auto pos = position_buffer.get_read_access(depends_list);
265
266 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
267 sycl::accessor tnode{shambase::get_check_ref(serial_tree_buf), cgh, sycl::read_only};
268 sycl::accessor linked_node_id{
269 shambase::get_check_ref(linked_patch_ids_buf), cgh, sycl::read_only};
270 sycl::accessor roots_id{roots, cgh, sycl::read_only};
271 sycl::accessor new_id{new_owned_id, cgh, sycl::write_only, sycl::no_init};
272
273 u32 root_cnt = roots_id.size();
274 auto max_lev = get_level_count();
275
277
278 cgh.parallel_for(sycl::range(len), [=](sycl::item<1> item) {
279 u32 i = (u32) item.get_id(0);
280
281 auto xyz = pos[i];
282
283 u64 current_node = 0;
284
285 // find the correct root to start the search
286 for (u32 iroot = 0; iroot < root_cnt; iroot++) {
287 u32 root_id = roots_id[iroot];
288 PtNode root_node = tnode[root_id];
289
290 if (Patch::is_in_patch_converted(xyz, root_node.box_min, root_node.box_max)) {
291 current_node = root_id;
292 break;
293 }
294 }
295
296 u64 result_node = u64_max;
297
298 for (u32 step = 0; step < max_lev + 1; step++) {
299 PtNode cur_node = tnode[current_node];
300
301 if (cur_node.childs_id[0] != u64_max) {
302
304 xyz,
305 tnode[cur_node.childs_id[0]].box_min,
306 tnode[cur_node.childs_id[0]].box_max)) {
307 current_node = cur_node.childs_id[0];
308 } else if (
310 xyz,
311 tnode[cur_node.childs_id[1]].box_min,
312 tnode[cur_node.childs_id[1]].box_max)) {
313 current_node = cur_node.childs_id[1];
314 } else if (
316 xyz,
317 tnode[cur_node.childs_id[2]].box_min,
318 tnode[cur_node.childs_id[2]].box_max)) {
319 current_node = cur_node.childs_id[2];
320 } else if (
322 xyz,
323 tnode[cur_node.childs_id[3]].box_min,
324 tnode[cur_node.childs_id[3]].box_max)) {
325 current_node = cur_node.childs_id[3];
326 } else if (
328 xyz,
329 tnode[cur_node.childs_id[4]].box_min,
330 tnode[cur_node.childs_id[4]].box_max)) {
331 current_node = cur_node.childs_id[4];
332 } else if (
334 xyz,
335 tnode[cur_node.childs_id[5]].box_min,
336 tnode[cur_node.childs_id[5]].box_max)) {
337 current_node = cur_node.childs_id[5];
338 } else if (
340 xyz,
341 tnode[cur_node.childs_id[6]].box_min,
342 tnode[cur_node.childs_id[6]].box_max)) {
343 current_node = cur_node.childs_id[6];
344 } else if (
346 xyz,
347 tnode[cur_node.childs_id[7]].box_min,
348 tnode[cur_node.childs_id[7]].box_max)) {
349 current_node = cur_node.childs_id[7];
350 }
351
352 } else {
353
354 result_node = linked_node_id[current_node];
355 break;
356 }
357 }
358
359 new_id[i] = result_node;
360 });
361 });
362
363 position_buffer.complete_event_state(e);
364
365 return new_owned_id;
366}
constexpr const char * xyz
Position field (3D coordinates).
MPI scheduler.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
The MPI scheduler.
PatchTree patch_tree
Handles the tree structure of the patches.
SchedulerPatchList patch_list
Handles the list of patches of the scheduler.
std::unordered_map< u64, u64 > id_patch_to_global_idx
id_patch_to_global_idx[patch_id] = index in global patch list
const u32 & get_level_count()
accessor to the number of levels in the tree
u32 get_element_count()
accessor to the number of elements in the tree
A buffer allocated in USM (Unified Shared Memory).
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
Patch Tree : Tree structure organisation for an abstract list of patches Nb : this tree is compatible...
Definition PatchTree.hpp:29
sycl::buffer< T > vec_to_buf(const std::vector< T > &buf)
Convert a std::vector to a sycl::buffer.
Definition memory.cpp:29
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
ExcptTypes make_except_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Create an exception with a message and a location.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
constexpr u64 u64_max
u64 max value
void raw_ln(Types... var2)
Prints a log message with multiple arguments followed by a newline.
Definition logs.hpp:90
This file contains the definition for the stacktrace related functionality.
shambase::details::BasicStackEntry StackEntry
Alias for shambase::details::BasicStackEntry.
static bool is_in_patch_converted(sycl::vec< T, 3 > val, sycl::vec< T, 3 > min_val, sycl::vec< T, 3 > max_val)
check whether a particle lies in the requested range, given the output of @convert_coord
Definition Patch.hpp:210
header file to manage sycl