Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
AMRGrid.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
20#include "AMRCell.hpp"
22#include "shamalgs/memory.hpp"
23#include "shamalgs/numeric.hpp"
28#include <vector>
29
30namespace shamrock::amr {
31
    /// List of selected cell indices for one patch, as produced by
    /// shamalgs::numeric::stream_compact (see its tuple return type).
    struct OptIndexList {
        std::optional<sycl::buffer<u32>> idx; // compacted indices of the flagged cells
                                              // (optional — presumably empty when count == 0,
                                              // TODO confirm against stream_compact)
        u32 count;                            // number of valid entries referenced by idx
    };
36
    /// The AMR grid only sees the grid as an integer map: each cell is the box
    /// [cell_min, cell_max] in integer coordinates of type Tcoord.
    /// @tparam Tcoord integer coordinate type of a cell corner
    /// @tparam dim    dimensionality of the grid (only dim == 3 is implemented)
    template<class Tcoord, u32 dim>
    class AMRGrid {
        public:
        /// Scheduler owning the patches this grid operates on (non-owning reference).
        PatchScheduler &sched;

        // NOTE(review): a `using CellCoord = ...;` alias (internal line 48) appears to be
        // missing from this extraction — CellCoord is used throughout; confirm in full file.
        static constexpr u32 dimension   = dim; ///< grid dimensionality
        static constexpr u32 split_count = CellCoord::splts_count; ///< children per refined cell
52 void check_amr_main_fields() {
53
54 bool correct_type = true;
55 correct_type &= sched.pdl_old().template check_field_type<Tcoord>(0);
56 correct_type &= sched.pdl_old().template check_field_type<Tcoord>(1);
57
58 bool correct_names = true;
59 correct_names &= sched.pdl_old().template get_field<Tcoord>(0).name == "cell_min";
60 correct_names &= sched.pdl_old().template get_field<Tcoord>(1).name == "cell_max";
61
62 if (!correct_type || !correct_names) {
63 throw std::runtime_error(
64 "the amr module require a layout in the form :\n"
65 " 0 : cell_min : nvar=1 type : (Coordinate type)\n"
66 " 1 : cell_max : nvar=1 type : (Coordinate type)\n\n"
67 "the current layout is : \n"
68 + sched.pdl_old().get_description_str());
69 }
70 }
71
        /// Build the AMR grid on top of an existing scheduler and immediately
        /// validate its layout (throws std::runtime_error if non-conform).
        explicit AMRGrid(PatchScheduler &scheduler) : sched(scheduler) { check_amr_main_fields(); }
73
            std::function<void(u64, patch::Patch, patch::PatchDataLayer &, sycl::buffer<u32> &)>
                fct) {

            // NOTE(review): the declaration of `ret`
            // (shambase::DistributedData<OptIndexList>, internal line 92) is missing
            // from this extraction even though `ret` is used below — confirm in full file.

            using namespace patch;

            // total number of cells flagged for refinement on this process (logging only)
            u64 tot_refine = 0;

            sched.for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
                sycl::queue &q = shamsys::instance::get_compute_queue();

                u32 obj_cnt = pdat.get_obj_cnt();

                // one flag per cell: non-zero <=> the cell must be refined
                sycl::buffer<u32> refine_flags(obj_cnt);

                // fill in the refinment flags via the user callback
                fct(id_patch, cur_p, pdat, refine_flags);

                // perform stream compactions on the refinement flags
                // -> indices of the flagged cells plus their count
                auto [buf, len] = shamalgs::numeric::stream_compact(q, refine_flags, obj_cnt);

                shamlog_debug_ln("AMRGrid", "patch ", id_patch, "refine cell count = ", len);

                tot_refine += len;

                // add the results to the map (one OptIndexList per patch)
                ret.add_obj(id_patch, OptIndexList{std::move(buf), len});
            });

            logger::info_ln("AMRGrid", "on this process", tot_refine, "cells were refined");

            return std::move(ret);
        }
123
124 template<class UserAcc, class Fct, class... T>
125 inline shambase::DistributedData<OptIndexList> gen_refine_list(Fct &&lambd, T &&...args) {
126 using namespace shamrock::patch;
127
128 return gen_refinelists_native([&](u64 id_patch,
129 Patch p,
130 PatchDataLayer &pdat,
131 sycl::buffer<u32> &refine_flags) {
132 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
133
134 sham::EventList depends_list;
135 sham::EventList resulting_events;
136
137 UserAcc uacc(depends_list, id_patch, p, pdat, args...);
138
139 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
140 sycl::accessor refine_acc{refine_flags, cgh, sycl::write_only, sycl::no_init};
141
142 cgh.parallel_for(sycl::range<1>(pdat.get_obj_cnt()), [=](sycl::item<1> gid) {
143 refine_acc[gid] = lambd(gid.get_linear_id(), uacc);
144 });
145 });
146
147 resulting_events.add_event(e);
148 uacc.finalize(resulting_events, id_patch, p, pdat, args...);
149 });
150 }
151
152 inline u64 get_process_refine_count(shambase::DistributedData<OptIndexList> &splits) {
153 u64 acc = 0;
154
155 splits.for_each([&acc](u64 id, OptIndexList &idx_list) {
156 acc += idx_list.count;
157 });
158
159 return acc;
160 }
161
        /// Build, for every patch, the list of merge candidates: windows of
        /// `split_count` consecutive cells (after Morton reordering) that are
        /// geometrically mergeable AND accepted by the user predicate `lambd`.
        ///
        /// NOTE(review): the declaration of `ret`
        /// (shambase::DistributedData<OptIndexList>, internal line 165) is missing
        /// from this extraction even though `ret` is used below — confirm in full file.
        template<class UserAcc, class Fct>
        shambase::DistributedData<OptIndexList> gen_merge_list(Fct &&lambd) {

            u64 tot_merge = 0;

            using MortonBuilder = RadixTreeMortonBuilder<u64, Tcoord, 3>;
            using namespace shamrock::patch;

            sched.for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
                // return because no cell can be merged since
                // there are fewer cells than one merge window
                if (pdat.get_obj_cnt() < split_count) {
                    return;
                }

                std::unique_ptr<sycl::buffer<u64>> out_buf_morton;
                std::unique_ptr<sycl::buffer<u32>> out_buf_particle_index_map;

                // build Morton codes from the cell_min field; presumably this orders
                // cells along the Morton curve so that sibling cells become
                // contiguous — TODO confirm against RadixTreeMortonBuilder
                MortonBuilder::build(
                    shamsys::instance::get_compute_scheduler_ptr(),
                    sched.get_sim_box().template patch_coord_to_domain<Tcoord>(cur_p),
                    pdat.get_field<Tcoord>(0).get_buf(),
                    pdat.get_obj_cnt(),
                    out_buf_morton,
                    out_buf_particle_index_map);

                // apply list permut on patch

                u32 pre_merge_obj_cnt = pdat.get_obj_cnt();

                pdat.index_remap(*out_buf_particle_index_map, pre_merge_obj_cnt);

                // number of sliding windows of split_count consecutive cells
                u32 obj_to_check = pre_merge_obj_cnt - split_count + 1;

                shamlog_debug_sycl_ln("AMR Grid", "checking mergeable in", obj_to_check, "cells");

                sycl::buffer<u32> mergeable_indexes(obj_to_check);

                sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();

                sham::EventList depends_list;
                auto acc_min = pdat.get_field<Tcoord>(0).get_buf().get_write_access(depends_list);
                auto acc_max = pdat.get_field<Tcoord>(1).get_buf().get_write_access(depends_list);

                // pass 1 (geometric test): flag window gid if the split_count cells
                // starting at gid can be merged into a single coarser cell
                auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
                    sycl::accessor acc_mergeable{
                        mergeable_indexes, cgh, sycl::write_only, sycl::no_init};

                    sycl::range<1> rnge{obj_to_check};

                    cgh.parallel_for(rnge, [=](sycl::item<1> gid) {
                        u32 id = gid.get_linear_id();

                        std::array<CellCoord, split_count> cells;

                        for (u32 lid = 0; lid < split_count; lid++) {
                            cells[lid] = CellCoord{acc_min[gid + lid], acc_max[gid + lid]};
                        }

                        acc_mergeable[gid] = CellCoord::are_mergeable(cells);
                    });
                });

                pdat.get_field<Tcoord>(0).get_buf().complete_event_state(e);
                pdat.get_field<Tcoord>(1).get_buf().complete_event_state(e);

                // pass 2 (user criterion): keep only candidates accepted by `lambd`
                {
                    sham::EventList depends_list;
                    sham::EventList resulting_events;
                    UserAcc uacc(depends_list, id_patch, cur_p, pdat);

                    auto e2 = q.submit(depends_list, [&](sycl::handler &cgh) {
                        sycl::accessor acc_mergeable{mergeable_indexes, cgh, sycl::read_write};

                        // NOTE(review): this kernel ranges over pdat.get_obj_cnt() items
                        // but mergeable_indexes only holds obj_to_check
                        // (= obj_cnt - split_count + 1) entries, so the last
                        // split_count-1 work items index past the buffer — confirm
                        // whether the range should be obj_to_check.
                        cgh.parallel_for(
                            sycl::range<1>(pdat.get_obj_cnt()), [=](sycl::item<1> gid) {
                                if (acc_mergeable[gid]) {
                                    acc_mergeable[gid] = lambd(gid.get_linear_id(), uacc);
                                }
                            });
                    });

                    resulting_events.add_event(e2);
                    uacc.finalize(resulting_events, id_patch, cur_p, pdat);
                }

                // compact the flags into the list of window start indices
                auto [opt_buf, len] = shamalgs::numeric::stream_compact(
                    shamsys::instance::get_compute_queue(), mergeable_indexes, obj_to_check);

                shamlog_debug_ln("AMRGrid", "patch ", id_patch, "merge cell count = ", len);

                tot_merge += len;

                // add the results to the map
                ret.add_obj(id_patch, OptIndexList{std::move(opt_buf), len});
            });

            logger::info_ln(
                "AMRGrid", "on this process", tot_merge * split_count, "cells were derefined");

            return std::move(ret);
        }
264
        /// Apply refinement: split every flagged cell of every patch into
        /// `split_count` children. Child 0 overwrites the parent cell in place; the
        /// other split_count-1 children are appended at the end of the patch data.
        /// `lambd(parent_idx, parent_coord, child_ids, child_coords, uacc)` fills the
        /// remaining (user) fields of the new cells.
        ///
        /// NOTE(review): the signature line of this method (internal line 274) is
        /// missing from this extraction; per the generated cross-reference it reads
        /// `void apply_splits(shambase::DistributedData<OptIndexList> &&splts, Fct &&lambd) {`.
        template<class UserAcc, class Fct>

            using namespace patch;

            // total cell count on this process after the splits (logging only)
            u64 sum_cell_count = 0;

            sched.for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
                sycl::queue &q = shamsys::instance::get_compute_queue();

                u32 old_obj_cnt = pdat.get_obj_cnt();

                OptIndexList &refine_flags = splts.get(id_patch);

                if (refine_flags.count > 0) {

                    // make room for the split_count-1 extra children of each split cell
                    pdat.expand(refine_flags.count * (split_count - 1));

                    // NOTE(review): shadows the outer `sycl::queue &q` above
                    sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();

                    sham::EventList depends_list;

                    auto cell_bound_low
                        = pdat.get_field<Tcoord>(0).get_buf().get_write_access(depends_list);
                    auto cell_bound_high
                        = pdat.get_field<Tcoord>(1).get_buf().get_write_access(depends_list);

                    sham::EventList resulting_events;
                    UserAcc uacc(depends_list, pdat);

                    auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
                        sycl::accessor index_to_ref{*refine_flags.idx, cgh, sycl::read_only};

                        // new cells are appended after the pre-split cell count
                        u32 start_index_push = old_obj_cnt;

                        constexpr u32 new_splits = split_count - 1;

                        cgh.parallel_for(
                            sycl::range<1>(refine_flags.count), [=](sycl::item<1> gid) {
                                u32 tid = gid.get_linear_id();

                                u32 idx_to_refine = index_to_ref[gid];

                                // gen splits coordinates
                                CellCoord cur_cell{
                                    cell_bound_low[idx_to_refine], cell_bound_high[idx_to_refine]};

                                std::array<CellCoord, split_count> cell_coords
                                    = CellCoord::get_split(cur_cell.bmin, cur_cell.bmax);

                                // generate index for the refined cells
                                // (slot 0 reuses the parent slot, the rest are appended)
                                std::array<u32, split_count> cells_ids;
                                cells_ids[0] = idx_to_refine;

#pragma unroll
                                for (u32 pid = 0; pid < new_splits; pid++) {
                                    cells_ids[pid + 1] = start_index_push + tid * new_splits + pid;
                                }

                                // write coordinates

#pragma unroll
                                for (u32 pid = 0; pid < split_count; pid++) {
                                    cell_bound_low[cells_ids[pid]] = cell_coords[pid].bmin;
                                    cell_bound_high[cells_ids[pid]] = cell_coords[pid].bmax;
                                }

                                // user lambda to fill the fields
                                lambd(idx_to_refine, cur_cell, cells_ids, cell_coords, uacc);
                            });
                    });

                    pdat.get_field<Tcoord>(0).get_buf().complete_event_state(e);
                    pdat.get_field<Tcoord>(1).get_buf().complete_event_state(e);

                    resulting_events.add_event(e);
                    uacc.finalize(resulting_events, pdat);
                }

                sum_cell_count += pdat.get_obj_cnt();
            });

            logger::info_ln("AMRGrid", "process cell count =", sum_cell_count);
        }
357
        /// Apply the merge candidates produced by gen_merge_list: each flagged window
        /// of `split_count` consecutive cells is collapsed into one merged cell written
        /// at the window's first index; the other cells of the window are flagged for
        /// removal and compacted away via index_remap_resize.
        /// `lambd(old_indexes, old_cell_coords, merged_idx, merged_coord, uacc)` is
        /// responsible for merging the user field data.
        template<class UserAcc, class Fct>
        void apply_merge(shambase::DistributedData<OptIndexList> &&splts, Fct &&lambd) {

            using namespace patch;

            sched.for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchDataLayer &pdat) {
                sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();

                u32 old_obj_cnt = pdat.get_obj_cnt();

                OptIndexList &derefine_flags = splts.get(id_patch);

                if (derefine_flags.count > 0) {

                    // init flag table (1 = keep the cell, 0 = remove it)
                    sycl::buffer<u32> keep_cell_flag = shamalgs::algorithm::gen_buffer_device(
                        q.q, old_obj_cnt, [](u32 i) -> u32 {
                            return 1;
                        });

                    sham::EventList depends_list;
                    auto cell_bound_low
                        = pdat.get_field<Tcoord>(0).get_buf().get_write_access(depends_list);
                    auto cell_bound_high
                        = pdat.get_field<Tcoord>(1).get_buf().get_write_access(depends_list);

                    sham::EventList resulting_events;
                    UserAcc uacc(depends_list, pdat);

                    // edit cell content + make flag of cells to keep
                    auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
                        sycl::accessor index_to_deref{*derefine_flags.idx, cgh, sycl::read_only};

                        sycl::accessor flag_keep{keep_cell_flag, cgh, sycl::read_write};

                        cgh.parallel_for(
                            sycl::range<1>(derefine_flags.count), [=](sycl::item<1> gid) {
                                u32 tid = gid.get_linear_id();

                                u32 idx_to_derefine = index_to_deref[gid];

                                // compute old cell indexes
                                // (the window is split_count consecutive cells)
                                std::array<u32, split_count> old_indexes;
#pragma unroll
                                for (u32 pid = 0; pid < split_count; pid++) {
                                    old_indexes[pid] = idx_to_derefine + pid;
                                }

                                // load cell coords
                                std::array<CellCoord, split_count> cell_coords;
#pragma unroll
                                for (u32 pid = 0; pid < split_count; pid++) {
                                    cell_coords[pid] = CellCoord{
                                        cell_bound_low[old_indexes[pid]],
                                        cell_bound_high[old_indexes[pid]]};
                                }

                                // make new cell coord
                                CellCoord merged_cell_coord = CellCoord::get_merge(cell_coords);

                                // write new coord in place of the window's first cell
                                cell_bound_low[idx_to_derefine] = merged_cell_coord.bmin;
                                cell_bound_high[idx_to_derefine] = merged_cell_coord.bmax;

// flag the old cells for removal
#pragma unroll
                                for (u32 pid = 1; pid < split_count; pid++) {
                                    flag_keep[idx_to_derefine + pid] = 0;
                                }

                                // user lambda to fill the fields
                                lambd(
                                    old_indexes,
                                    cell_coords,
                                    idx_to_derefine,
                                    merged_cell_coord,
                                    uacc);
                            });
                    });

                    pdat.get_field<Tcoord>(0).get_buf().complete_event_state(e);
                    pdat.get_field<Tcoord>(1).get_buf().complete_event_state(e);
                    resulting_events.add_event(e);
                    uacc.finalize(resulting_events, pdat);

                    // stream compact the flags -> indices of the surviving cells
                    auto [opt_buf, len]
                        = shamalgs::numeric::stream_compact(q.q, keep_cell_flag, old_obj_cnt);

                    shamlog_debug_ln(
                        "AMR Grid",
                        "patch",
                        id_patch,
                        "derefine cell count ",
                        old_obj_cnt,
                        "->",
                        len);

                    if (!opt_buf) {
                        throw std::runtime_error("opt buf must contain something at this point");
                    }

                    // remap pdat according to stream compact
                    pdat.index_remap_resize(*opt_buf, len);
                }
            });
        }
465
466 inline void make_base_grid(Tcoord bmin, Tcoord cell_size, std::array<u32, dim> cell_count) {
467
468 Tcoord bmax{
469 bmin.x() + cell_size.x() * (cell_count[0]),
470 bmin.y() + cell_size.y() * (cell_count[1]),
471 bmin.z() + cell_size.z() * (cell_count[2])};
472
473 sched.set_coord_domain_bound(bmin, bmax);
474
475 if ((cell_size.x() != cell_size.y()) || (cell_size.y() != cell_size.z())) {
476 logger::warn_ln("AMR Grid", "your cells aren't cube");
477 }
478
479 static_assert(dim == 3, "this is not implemented for dim != 3");
480
481 std::array<u32, dim> patch_count;
482
483 constexpr u32 gcd_pow2 = 1U << 31U;
484 u32 gcd_cell_count;
485 {
486 gcd_cell_count = std::gcd(cell_count[0], cell_count[1]);
487 gcd_cell_count = std::gcd(gcd_cell_count, cell_count[2]);
488 gcd_cell_count = std::gcd(gcd_cell_count, gcd_pow2);
489 }
490
491 shamlog_debug_ln(
492 "AMRGrid",
493 "patch grid :",
494 cell_count[0] / gcd_cell_count,
495 cell_count[1] / gcd_cell_count,
496 cell_count[2] / gcd_cell_count);
497
498 sched.make_patch_base_grid<3>(
499 {{cell_count[0] / gcd_cell_count,
500 cell_count[1] / gcd_cell_count,
501 cell_count[2] / gcd_cell_count}});
502
503 sched.for_each_patch([](u64 id_patch, const patch::Patch &p) {
504 // TODO implement check to verify that patch a cubes of size 2^n
505 });
506
507 u32 cell_tot_count = cell_count[0] * cell_count[1] * cell_count[2];
508
509 sycl::buffer<Tcoord> cell_coord_min(cell_tot_count);
510 sycl::buffer<Tcoord> cell_coord_max(cell_tot_count);
511
512 shamlog_debug_sycl_ln(
513 "AMRGrid", "building bounds ", cell_count[0], cell_count[1], cell_count[2]);
514
515 {
516 sycl::host_accessor acc_min{cell_coord_min, sycl::write_only, sycl::no_init};
517 sycl::host_accessor acc_max{cell_coord_max, sycl::write_only, sycl::no_init};
518
519 sycl::range<3> rnge{cell_count[0], cell_count[1], cell_count[2]};
520
521 u32 cnt_x = cell_count[0];
522 u32 cnt_y = cell_count[1];
523 u32 cnt_z = cell_count[2];
524
525 u32 cnt_xy = cnt_x * cnt_y;
526
527 Tcoord sz = cell_size;
528
529 for (u64 idx = 0; idx < cell_count[0]; idx++) {
530 for (u64 idy = 0; idy < cell_count[1]; idy++) {
531 for (u64 idz = 0; idz < cell_count[2]; idz++) {
532
533 u64 id_a = idx + cnt_x * idy + cnt_xy * idz;
534
535 acc_min[id_a] = sz * Tcoord{idx, idy, idz};
536 acc_max[id_a] = sz * Tcoord{idx + 1, idy + 1, idz + 1};
537 }
538 }
539 }
540 }
541
542 shambase::check_queue_state(shamsys::instance::get_compute_queue());
543
544 patch::PatchDataLayer pdat(sched.get_layout_ptr_old());
545 pdat.resize(cell_tot_count);
546
547 shambase::check_queue_state(shamsys::instance::get_compute_queue());
548 pdat.get_field<Tcoord>(0).override(cell_coord_min, cell_tot_count);
549
550 shambase::check_queue_state(shamsys::instance::get_compute_queue());
551 pdat.get_field<Tcoord>(1).override(cell_coord_max, cell_tot_count);
552
553 shambase::check_queue_state(shamsys::instance::get_compute_queue());
554
555 sched.allpush_data(pdat);
556
557 shambase::check_queue_state(shamsys::instance::get_compute_queue());
558
559 shamlog_debug_sycl_ln("AMRGrid", "grid init done");
560 }
561 };
562
564 // out of line implementation
566
567 // template<class Tcoord, u32 dim>
568 // inline auto
569 // AMRGrid<Tcoord, dim>::gen_splitlists(std::function<sycl::buffer<u32>(u64 , patch::Patch ,
570 // patch::PatchData &)> fct) -> shambase::DistributedData<SplitList> {
571 //
572 // shambase::DistributedData<SplitList> ret;
573 //
574 // using namespace patch;
575 //
576 // sched.for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchData &pdat) {
577 // sycl::queue &q = shamsys::instance::get_compute_queue();
578 //
579 // u32 obj_cnt = pdat.get_obj_cnt();
580 //
581 // sycl::buffer<u32> split_flags = fct(id_patch, cur_p, pdat);
582 //
583 // auto [buf, len] = shamalgs::numeric::stream_compact(q, split_flags, obj_cnt);
584 //
585 // ret.add_obj(id_patch, SplitList{std::move(buf), len});
586 // });
587 //
588 // return std::move(ret);
589 //}
590
591} // namespace shamrock::amr
MPI scheduler.
Utility to build morton codes for the radix tree.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
The MPI scheduler.
void for_each_patch_data(Function &&fct)
for-each macro over the patch data; see the example usage.
void set_coord_domain_bound(vectype bmin, vectype bmax)
modify the bounding box of the patch domain
void allpush_data(shamrock::patch::PatchDataLayer &pdat)
Push data into the scheduler. The content of pdat has to be the same on each node.
Helper class to build morton codes.
A SYCL queue associated with a device and a context.
sycl::queue q
The SYCL queue associated with this context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
void add_event(sycl::event e)
Add an event to the list of events.
Definition EventList.hpp:87
Represents a collection of objects distributed across patches identified by a u64 id.
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
void for_each(std::function< void(u64, T &)> &&f)
Applies a function to each object in the collection.
The AMR grid only sees the grid as an integer map.
Definition AMRGrid.hpp:44
void apply_splits(shambase::DistributedData< OptIndexList > &&splts, Fct &&lambd)
Definition AMRGrid.hpp:274
shambase::DistributedData< OptIndexList > gen_refinelists_native(std::function< void(u64, patch::Patch, patch::PatchDataLayer &, sycl::buffer< u32 > &)> fct)
generate split lists for all patchdata owned by the node
Definition AMRGrid.hpp:88
std::string get_description_str() const
Get the description of the layout.
PatchDataLayer container class, the layout is described in patchdata_layout.
void index_remap(sycl::buffer< u32 > &index_map, u32 len)
this function remaps the patch data field like so: val[id] = val[index_map[id]]. This function can be used...
void index_remap_resize(sycl::buffer< u32 > &index_map, u32 len)
this function remaps the patch data field like so: val[id] = val[index_map[id]]. This function can be used...
sycl::buffer< typename std::invoke_result_t< Fct, u32 > > gen_buffer_device(sycl::queue &q, u32 len, Fct &&func)
generate a buffer from a lambda expression based on the indexes
Definition algorithm.hpp:65
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
Definition numeric.cpp:84
main include file for memory algorithms
Patch object that contain generic patch information.
Definition Patch.hpp:33