Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
AMROverheadtest.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
19#include "shambase/time.hpp"
20#include "shamalgs/memory.hpp"
29#include <vector>
30
32 public:
// Reference to the AMR grid being exercised by this overhead test;
// not owned — the grid must outlive this model instance.
34 Grid &grid;
35
// Wraps an existing grid; only stores the reference, performs no other setup.
36 explicit AMRTestModel(Grid &grd) : grid(grd) {}
37
// Accessor handed to the refine/merge criterion kernels: exposes read-only
// device pointers to the per-cell integer bound fields of a patch.
// NOTE(review): generated listing — the enclosing class declaration (original
// line 38) and part of the constructor signature (lines 43, 46-47) are not
// visible in this view.
39 public:
// Device pointers to the per-cell lower / upper coordinate bounds,
// taken from fields 0 and 1 of the patch data layer.
40 const u64_3 *cell_low_bound;
41 const u64_3 *cell_high_bound;
42
44 sham::EventList &depends_list,
45 u64 id_patch,
48
// Acquire read access to both bound buffers, recording the dependency
// events in depends_list.
49 sham::DeviceBuffer<u64_3> &buf_cell_low_bound = pdat.get_field<u64_3>(0).get_buf();
50 sham::DeviceBuffer<u64_3> &buf_cell_high_bound = pdat.get_field<u64_3>(1).get_buf();
51 cell_low_bound = buf_cell_low_bound.get_read_access(depends_list);
52 cell_high_bound = buf_cell_high_bound.get_read_access(depends_list);
53 }
54
// Hand the events produced by the consuming kernel back to the two bound
// buffers so their event state is completed.
55 void finalize(
56 sham::EventList &resulting_events,
57 u64 id_patch,
60
61 sham::DeviceBuffer<u64_3> &buf_cell_low_bound = pdat.get_field<u64_3>(0).get_buf();
62 sham::DeviceBuffer<u64_3> &buf_cell_high_bound = pdat.get_field<u64_3>(1).get_buf();
63
64 buf_cell_low_bound.complete_event_state(resulting_events);
65 buf_cell_high_bound.complete_event_state(resulting_events);
66 }
67 };
68
68
// Convenience alias: 1D read-only SYCL device accessor for a buffer of T.
69 template<class T>
70 using buf_access_read = sycl::accessor<T, 1, sycl::access::mode::read, sycl::target::device>;
// Convenience alias: 1D read-write SYCL device accessor for a buffer of T.
71 template<class T>
72 using buf_access_read_write
73 = sycl::accessor<T, 1, sycl::access::mode::read_write, sycl::target::device>;
74
// Accessor used by the split/merge apply kernels: exposes a writable device
// pointer to the u32 payload field (field index 2) of a patch.
// NOTE(review): generated listing — the class declaration (original line 75)
// and the constructor signature (line 79) are not visible in this view.
76 public:
// Writable device pointer to the per-cell u32 test field.
77 u32 *field;
78
80
81 auto &buf_field = pdat.get_field<u32>(2).get_buf();
82 field = buf_field.get_write_access(depends_list);
83 }
84
// Complete the field buffer's event state with the events produced by the
// kernel that used this accessor.
85 void finalize(sham::EventList &resulting_events, shamrock::patch::PatchDataLayer &pdat) {
86 auto &buf_field = pdat.get_field<u32>(2).get_buf();
87 buf_field.complete_event_state(resulting_events);
88 }
89 };
90
90
// Debug helper: print the (min, max) integer bounds of every cell of patch
// `id` via the logger. Performs blocking device-to-host copies of both bound
// fields, so it is intended for debugging only (its call sites in
// refine()/derefine() are commented out).
91 inline void dump_patch(u64 id) {
92
93 using namespace shamrock::patch;
94 using namespace shamalgs::memory;
95
96 PatchDataLayer &pdat = grid.sched.patch_data.get_pdat(id);
97
// Copy the cell lower/upper bounds (fields 0 and 1) back to host vectors.
98 std::vector<u64_3> mins = pdat.get_field<u64_3>(0).get_buf().copy_to_stdvec();
99 std::vector<u64_3> maxs = pdat.get_field<u64_3>(1).get_buf().copy_to_stdvec();
100
101 logger::raw_ln("----- dump");
102 for (u32 i = 0; i < mins.size(); i++) {
103 logger::raw_ln(mins[i], maxs[i]);
104 }
105 logger::raw_ln("-----");
106 }
107
107
// Scale factor applied to the integer coordinate windows used by the
// refine()/derefine() criteria below (the windows are multiples of this).
108 static constexpr u64 fact_p_len = 2;
109
// Refinement step: build a per-cell split list from a device-side criterion,
// then apply the splits, copying the parent's u32 field value into each of
// its 8 children.
// NOTE(review): generated listing — original line 137 (the call that applies
// the splits, presumably grid.apply_splits<RefineCellAccessor>( per the
// cross-reference section) is not visible in this view.
114 inline void refine() {
115
116 // dump_patch(4);
// Criterion (returns u32 used as a bool flag): refine a cell iff both of
// its bounds lie in the half-open window
// [fact_p_len*(1,1,1), fact_p_len*(4,4,4)) and the cell extent is > 1 on
// every axis (so it can still be subdivided).
117 auto splits = grid.gen_refine_list<RefineCritCellAccessor>(
118 [](u32 cell_id, RefineCritCellAccessor acc) -> u32 {
119 u64_3 low_bound = acc.cell_low_bound[cell_id];
120 u64_3 high_bound = acc.cell_high_bound[cell_id];
121
122 using namespace shammath;
123
124 bool should_refine
125 = is_in_half_open(
126 low_bound, fact_p_len * u64_3{1, 1, 1}, fact_p_len * u64_3{4, 4, 4})
127 && is_in_half_open(
128 high_bound, fact_p_len * u64_3{1, 1, 1}, fact_p_len * u64_3{4, 4, 4});
129
130 should_refine = should_refine && (high_bound.x() - low_bound.x() > 1);
131 should_refine = should_refine && (high_bound.y() - low_bound.y() > 1);
132 should_refine = should_refine && (high_bound.z() - low_bound.z() > 1);
133
134 return should_refine;
135 });
136
138 std::move(splits),
139
// Split functor: replicate the parent's field value into all 8 new cells.
140 [](u32 cur_idx,
141 Grid::CellCoord cur_coords,
142 std::array<u32, 8> new_cells,
143 std::array<Grid::CellCoord, 8> new_cells_coords,
144 RefineCellAccessor acc) {
145 u32 val = acc.field[cur_idx];
146
147#pragma unroll
148 for (u32 pid = 0; pid < 8; pid++) {
149 acc.field[new_cells[pid]] = val;
150 }
151 }
152
153 );
154
155 // dump_patch(4);
156 }
157
157
// Derefinement (merge) step: flag cells for merging using the same coordinate
// window as refine(), then collapse each group of 8 sibling cells into one,
// storing the truncated integer mean of the 8 field values in the merged cell.
158 inline void derefine() {
159 auto merge = grid.gen_merge_list<RefineCritCellAccessor>(
160 [](u32 cell_id, RefineCritCellAccessor acc) -> u32 {
161 u64_3 low_bound = acc.cell_low_bound[cell_id];
162 u64_3 high_bound = acc.cell_high_bound[cell_id];
163
164 using namespace shammath;
165
// Merge cells whose bounds fall inside the half-open window
// [fact_p_len*(1,1,1), fact_p_len*(4,4,4)).
// NOTE(review): original line 169 (the operator joining the two
// is_in_half_open calls — presumably "&& is_in_half_open(" to mirror
// refine()) is not visible in this generated listing.
166 bool should_merge
167 = is_in_half_open(
168 low_bound, fact_p_len * u64_3{1, 1, 1}, fact_p_len * u64_3{4, 4, 4})
170 high_bound, fact_p_len * u64_3{1, 1, 1}, fact_p_len * u64_3{4, 4, 4});
171
172 return should_merge;
173 });
174
175 grid.apply_merge<RefineCellAccessor>(
176 std::move(merge),
177
178 [](std::array<u32, 8> old_cells,
179 std::array<Grid::CellCoord, 8> old_coords,
180 u32 new_cell,
181 Grid::CellCoord new_coord,
182
183 RefineCellAccessor acc) {
// Sum the 8 old values, then divide by 8 (integer division truncates).
184 u32 accum = 0;
185
186#pragma unroll
187 for (u32 pid = 0; pid < 8; pid++) {
188 accum += acc.field[old_cells[pid]];
189 }
190
191 acc.field[new_cell] = accum / 8;
192 }
193
194 );
195 // dump_patch(4);
196 }
197
197
// One benchmark step: refine then derefine the grid, and for each patch build
// a radix tree over the cell lower-bound field, then run two tree walks that
// count interacting cells/objects — the first (hand-rolled rtree_for walk) is
// timed and logged, the second goes through the TreeStructureWalker API.
// NOTE(review): generated listing — several original lines are not visible in
// this view (e.g. 210 tree construction call head, 226 WalkAccessors ctor
// signature, 240 timer declaration, 252, 269, 295-297, 338); comments below
// hedge where the missing code is inferred.
198 inline void step() {
199
200 using namespace shamrock::patch;
201
202 refine();
203 derefine();
204
// NOTE(review): duplicate using-directive — shamrock::patch is already
// in scope from the directive a few lines above.
205 using namespace shamrock::patch;
206
207 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
208
// Per-patch work. The tree construction statement (original line 210,
// presumably declaring `tree` from a RadixTree-like type) is missing
// from this view; these are its arguments.
209 grid.sched.for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
211 shamsys::instance::get_compute_scheduler_ptr(),
212 grid.sched.get_sim_box().patch_coord_to_domain<u64_3>(cur_p),
213 pdat.get_field<u64_3>(0).get_buf(),
214 pdat.get_obj_cnt(),
215 0);
216
217 tree.compute_cell_ibounding_box(q.q);
218
219 tree.convert_bounding_box(q.q);
220
// Local accessor for the walk kernel: writable pointer to the u32
// result field (field 2). Ctor signature (original line 226) not
// visible in this view.
221 class WalkAccessors {
222 public:
223 u32 *field;
224
225 WalkAccessors(
227 auto &buf_field = pdat.get_field<u32>(2).get_buf();
228 field = buf_field.get_write_access(depends_list);
229 }
230
231 void finalize(
232 sham::EventList &resulting_events, shamrock::patch::PatchDataLayer &pdat) {
233 auto &buf_field = pdat.get_field<u32>(2).get_buf();
234 buf_field.complete_event_state(resulting_events);
235 }
236 };
237
// Drain the queue so the timer measures only the walk kernel.
// (Timer declaration, original line 240, is missing from this view.)
238 q.q.wait();
239
241 t.start();
242
243 sham::EventList depends_list;
244 sham::EventList resulting_events;
245
246 WalkAccessors uacc(depends_list, pdat);
247
248 auto cell_low_bound = pdat.get_field<u64_3>(0).get_buf().get_read_access(depends_list);
249 auto cell_high_bound = pdat.get_field<u64_3>(1).get_buf().get_read_access(depends_list);
250
// First walk: one work-item per cell; descend the tree keeping nodes
// whose bounding box interacts with the current cell's AABB, and count
// the candidate cells found. (Original lines 252 and 269 — the Rta
// alias and the interaction-test return statement — are not visible.)
251 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
253 Rta tree_acc(tree, cgh);
254
255 sycl::range range_npart{pdat.get_obj_cnt()};
256
257 cgh.parallel_for(range_npart, [=](sycl::item<1> item) {
258 u64_3 low_bound_a = cell_low_bound[item];
259 u64_3 high_bound_a = cell_high_bound[item];
260
261 u32 sum = 0;
262
263 walker::rtree_for(
264 tree_acc,
265 [&](u32 node_id) {
266 u64_3 cur_pos_min_cell_b = tree_acc.pos_min_cell[node_id];
267 u64_3 cur_pos_max_cell_b = tree_acc.pos_max_cell[node_id];
268
270 low_bound_a, high_bound_a, cur_pos_min_cell_b, cur_pos_max_cell_b);
271 },
272 [&](u32 id_b) {
273 // compute only omega_a
274
275 sum += 1;
276 },
277 [](u32 node_id) {});
278
279 uacc.field[item] = sum;
280 });
281 });
282
// Propagate the kernel event to all buffers that were accessed.
283 resulting_events.add_event(e);
284 uacc.finalize(resulting_events, pdat);
285 pdat.get_field<u64_3>(0).get_buf().complete_event_state(e);
286 pdat.get_field<u64_3>(1).get_buf().complete_event_state(e);
287
288 q.q.wait();
289 t.end();
290
291 shamlog_debug_ln("AMR Test", "walk time", t.get_time_str());
292
// Interaction criterion for the TreeStructureWalker-based second walk:
// a node interacts with a cell when their AABBs are connected.
// (Original lines 295-297 — tree member declarations — not visible.)
293 class InteractionCrit {
294 public:
296
298 PatchDataLayer &pdat;
299
300 sycl::buffer<u64_3> buf_cell_low_bound;
301 sycl::buffer<u64_3> buf_cell_high_bound;
302
// Device-side view of the criterion: accessors over the cell bounds
// and the tree's cell coordinate ranges.
303 class Access {
304 public:
305 sycl::accessor<u64_3, 1, sycl::access::mode::read> cell_low_bound;
306 sycl::accessor<u64_3, 1, sycl::access::mode::read> cell_high_bound;
307
308 sycl::accessor<u64_3, 1, sycl::access::mode::read> tree_cell_coordrange_min;
309 sycl::accessor<u64_3, 1, sycl::access::mode::read> tree_cell_coordrange_max;
310
311 Access(InteractionCrit crit, sycl::handler &cgh)
312 : cell_low_bound{crit.buf_cell_low_bound, cgh, sycl::read_only},
313 cell_high_bound{crit.buf_cell_high_bound, cgh, sycl::read_only},
314 tree_cell_coordrange_min{
315 *crit.tree.tree_cell_ranges.buf_pos_min_cell_flt,
316 cgh,
317 sycl::read_only},
318 tree_cell_coordrange_max{
319 *crit.tree.tree_cell_ranges.buf_pos_max_cell_flt,
320 cgh,
321 sycl::read_only} {}
322
// Per-work-item cached values: the bounds of the cell whose
// neighbourhood is being walked.
323 class ObjectValues {
324 public:
325 u64_3 cell_low_bound;
326 u64_3 cell_high_bound;
327 ObjectValues(Access acc, u32 index)
328 : cell_low_bound(acc.cell_low_bound[index]),
329 cell_high_bound(acc.cell_high_bound[index]) {}
330 };
331 };
332
// Decide whether to descend into node_index. The return statement
// (original line 338 — presumably domain_are_connected(...)) is not
// visible in this view; these are its arguments.
333 static bool criterion(
334 u32 node_index, Access acc, Access::ObjectValues current_values) {
335 u64_3 cur_pos_min_cell_b = acc.tree_cell_coordrange_min[node_index];
336 u64_3 cur_pos_max_cell_b = acc.tree_cell_coordrange_max[node_index];
337
339 current_values.cell_low_bound,
340 current_values.cell_high_bound,
341 cur_pos_min_cell_b,
342 cur_pos_max_cell_b);
343 };
344 };
345
346 using Criterion = InteractionCrit;
347 using CriterionAcc = typename Criterion::Access;
348 using CriterionVal = typename CriterionAcc::ObjectValues;
349
350 using namespace shamrock::tree;
351
// Build the walker; the bound fields are copied into plain SYCL
// buffers owned by the criterion object.
352 TreeStructureWalker walk = generate_walk<Recompute>(
353 tree.tree_struct,
354 pdat.get_obj_cnt(),
355 InteractionCrit{
356 {},
357 tree,
358 pdat,
359 pdat.get_field<u64_3>(0).get_buf().copy_to_sycl_buffer(),
360 pdat.get_field<u64_3>(1).get_buf().copy_to_sycl_buffer()});
361
// Second walk: count objects inside each interacting leaf.
// NOTE(review): `sum` is accumulated but never written out here, unlike
// the first walk which stores it into uacc.field — appears intentional
// for an overhead benchmark, but worth confirming.
362 q.submit([&](sycl::handler &cgh) {
363 auto walker = walk.get_access(cgh);
364 auto leaf_iterator = tree.get_leaf_access(cgh);
365
366 cgh.parallel_for(walker.get_sycl_range(), [=](sycl::item<1> item) {
367 u32 sum = 0;
368
369 CriterionVal int_values{
370 walker.criterion(), static_cast<u32>(item.get_linear_id())};
371
372 walker.for_each_node(
373 item,
374 int_values,
375 [&](u32 /*node_id*/, u32 leaf_iterator_id) {
376 leaf_iterator.iter_object_in_leaf(
377 leaf_iterator_id, [&](u32 /*obj_id*/) {
378 sum += 1;
379 });
380 },
381 [&](u32 node_id) {});
382 });
383 });
384 });
385 }
386};
Header file describing a Node Instance.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
void refine()
does the refinement step of the AMR
void for_each_patch_data(Function &&fct)
for-each macro for patchdata; example usage
SchedulerPatchData patch_data
handle the data of the patches of the scheduler
The radix tree.
Definition RadixTree.hpp:50
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
sycl::queue q
The SYCL queue associated with this context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
void add_event(sycl::event e)
Add an event to the list of events.
Definition EventList.hpp:87
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:96
std::string get_time_str() const
Converts the stored nanosecond time to a string representation.
Definition time.hpp:117
void end()
Stops the timer and stores the elapsed time in nanoseconds.
Definition time.hpp:111
void start()
Starts the timer.
Definition time.hpp:106
The AMR grid only sees the grid as an integer map.
Definition AMRGrid.hpp:44
void apply_splits(shambase::DistributedData< OptIndexList > &&splts, Fct &&lambd)
Definition AMRGrid.hpp:274
PatchDataLayer container class, the layout is described in patchdata_layout.
std::tuple< T, T > patch_coord_to_domain(const Patch &p) const
get the patch coordinates on the domain
Definition SimBox.hpp:300
memory manipulation algorithms
namespace for math utility
Definition AABB.hpp:26
bool domain_are_connected(T bmin1, T bmax1, T bmin2, T bmax2)
Check if two 1D intervals share boundary or overlap.
bool is_in_half_open(T val, T min, T max)
return true if val is in [min,max[
Definition intervals.hpp:36
main include file for memory algorithms
Patch object that contain generic patch information.
Definition Patch.hpp:33