// Compile-time AMR parameters: spatial dimensionality of the grid and the
// number of sub-cells produced by one refinement step (taken from CellCoord).
// NOTE(review): "splts_count" looks like a typo for "splits_count" in the
// CellCoord declaration — confirm against that type's definition.
49 static constexpr u32 dimension = dim;
50 static constexpr u32 split_count = CellCoord::splts_count;
// Validate that the scheduler's patch-data layout matches what the AMR module
// requires: field 0 and field 1 must both be of the coordinate type Tcoord and
// be named "cell_min" / "cell_max" respectively. Throws std::runtime_error
// with a descriptive message otherwise.
// NOTE(review): this excerpt is truncated — the error-message construction
// and the function's closing brace lie outside the visible lines.
52 void check_amr_main_fields() {
54 bool correct_type =
true;
// Fields 0 and 1 must both hold the coordinate type.
55 correct_type &= sched.pdl_old().template check_field_type<Tcoord>(0);
56 correct_type &= sched.pdl_old().template check_field_type<Tcoord>(1);
58 bool correct_names =
true;
// Field names are part of the AMR layout contract.
59 correct_names &= sched.pdl_old().template get_field<Tcoord>(0).name ==
"cell_min";
60 correct_names &= sched.pdl_old().template get_field<Tcoord>(1).name ==
"cell_max";
62 if (!correct_type || !correct_names) {
63 throw std::runtime_error(
64 "the amr module require a layout in the form :\n"
65 " 0 : cell_min : nvar=1 type : (Coordinate type)\n"
66 " 1 : cell_max : nvar=1 type : (Coordinate type)\n\n"
67 "the current layout is : \n"
// Fragment of a per-patch refine-list generator (its signature is outside the
// visible excerpt). For each patch it allocates a refine-flag buffer sized to
// the object count, lets the user functor `fct` fill it, logs the per-patch
// refine count, and finally logs the process-wide total before returning the
// accumulated result.
// NOTE(review): `cur_p`, `len`, `tot_refine` and `ret` are defined in lines
// not shown here — confirm their provenance in the full source.
94 using namespace patch;
99 sycl::queue &q = shamsys::instance::get_compute_queue();
101 u32 obj_cnt = pdat.get_obj_cnt();
// One flag slot per object (cell) in the patch.
103 sycl::buffer<u32> refine_flags(obj_cnt);
// User callback decides which cells to flag for refinement.
106 fct(id_patch, cur_p, pdat, refine_flags);
111 shamlog_debug_ln(
"AMRGrid",
"patch ", id_patch,
"refine cell count = ", len);
119 logger::info_ln(
"AMRGrid",
"on this process", tot_refine,
"cells were refined");
121 return std::move(ret);
// Fragment of a templated helper that evaluates a user lambda per cell on the
// device to fill `refine_flags`, then (further down) accumulates the flagged
// counts across the split lists. The function signature between the template
// header and the `using` line is not visible in this excerpt.
124 template<
class UserAcc,
class Fct,
class... T>
126 using namespace shamrock::patch;
131 sycl::buffer<u32> &refine_flags) {
// User-provided accessor bundle; forwards extra args to its constructor.
137 UserAcc uacc(depends_list, id_patch, p, pdat, args...);
139 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
// Write-only, no_init: every slot is overwritten by the kernel below.
140 sycl::accessor refine_acc{refine_flags, cgh, sycl::write_only, sycl::no_init};
// One work-item per object; the lambda decides the refine flag.
142 cgh.parallel_for(sycl::range<1>(pdat.get_obj_cnt()), [=](sycl::item<1> gid) {
143 refine_acc[gid] = lambd(gid.get_linear_id(), uacc);
148 uacc.finalize(resulting_events, id_patch, p, pdat, args...);
// Sum the per-patch flagged-cell counts (accumulator `acc` declared in
// lines not shown here).
155 splits.
for_each([&acc](
u64 id, OptIndexList &idx_list) {
156 acc += idx_list.count;
// Fragment of the merge-list generator: reorders cells along a Morton curve
// so that sibling cells become contiguous, then checks each window of
// `split_count` consecutive cells for mergeability, and finally lets a user
// lambda veto candidates. Several interior lines (signature, early-return
// body, stream-compaction call) are missing from this excerpt.
162 template<
class UserAcc,
class Fct>
169 using namespace shamrock::patch;
// Fewer cells than one split group: nothing can be merged.
173 if (pdat.get_obj_cnt() < split_count) {
177 std::unique_ptr<sycl::buffer<u64>> out_buf_morton;
178 std::unique_ptr<sycl::buffer<u32>> out_buf_particle_index_map;
// Build Morton codes + the permutation that sorts cells by them.
180 MortonBuilder::build(
181 shamsys::instance::get_compute_scheduler_ptr(),
182 sched.get_sim_box().template patch_coord_to_domain<Tcoord>(cur_p),
183 pdat.get_field<Tcoord>(0).get_buf(),
186 out_buf_particle_index_map);
190 u32 pre_merge_obj_cnt = pdat.get_obj_cnt();
// Apply the Morton permutation so sibling cells are adjacent in memory.
192 pdat.
index_remap(*out_buf_particle_index_map, pre_merge_obj_cnt);
// Sliding-window count: last valid window start is cnt - split_count.
194 u32 obj_to_check = pre_merge_obj_cnt - split_count + 1;
196 shamlog_debug_sycl_ln(
"AMR Grid",
"checking mergeable in", obj_to_check,
"cells");
198 sycl::buffer<u32> mergeable_indexes(obj_to_check);
203 auto acc_min = pdat.get_field<Tcoord>(0).get_buf().get_write_access(depends_list);
204 auto acc_max = pdat.get_field<Tcoord>(1).get_buf().get_write_access(depends_list);
206 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
207 sycl::accessor acc_mergeable{
208 mergeable_indexes, cgh, sycl::write_only, sycl::no_init};
210 sycl::range<1> rnge{obj_to_check};
// For each window start, gather split_count consecutive cells and ask
// CellCoord whether they form a mergeable sibling group.
212 cgh.parallel_for(rnge, [=](sycl::item<1> gid) {
// NOTE(review): `id` appears unused — indexing below uses `gid + lid`
// (sycl::item arithmetic) rather than `id + lid`; confirm intended.
213 u32 id = gid.get_linear_id();
215 std::array<CellCoord, split_count> cells;
217 for (
u32 lid = 0; lid < split_count; lid++) {
218 cells[lid] = CellCoord{acc_min[gid + lid], acc_max[gid + lid]};
221 acc_mergeable[gid] = CellCoord::are_mergeable(cells);
225 pdat.get_field<Tcoord>(0).get_buf().complete_event_state(e);
226 pdat.get_field<Tcoord>(1).get_buf().complete_event_state(e);
231 UserAcc uacc(depends_list, id_patch, cur_p, pdat);
// Second pass: user lambda may veto geometrically-mergeable windows.
233 auto e2 = q.
submit(depends_list, [&](sycl::handler &cgh) {
234 sycl::accessor acc_mergeable{mergeable_indexes, cgh, sycl::read_write};
237 sycl::range<1>(pdat.get_obj_cnt()), [=](sycl::item<1> gid) {
238 if (acc_mergeable[gid]) {
239 acc_mergeable[gid] = lambd(gid.get_linear_id(), uacc);
245 uacc.finalize(resulting_events, id_patch, cur_p, pdat);
// Presumably a stream-compaction of the surviving flags (call site's first
// line is missing from this excerpt) — verify against the full source.
249 shamsys::instance::get_compute_queue(), mergeable_indexes, obj_to_check);
251 shamlog_debug_ln(
"AMRGrid",
"patch ", id_patch,
"merge cell count = ", len);
256 ret.
add_obj(id_patch, OptIndexList{std::move(opt_buf), len});
260 "AMRGrid",
"on this process", tot_merge * split_count,
"cells were derefined");
262 return std::move(ret);
// Fragment of the split-application routine: grows the patch data to hold the
// new sub-cells, then a device kernel rewrites the refined cell's bounds in
// place (slot cells_ids[0]) and appends the remaining split_count-1 sub-cells
// at the end of the arrays. A user lambda is invoked per refined cell so
// callers can update their own per-cell fields. Signature and several interior
// lines are not visible in this excerpt.
273 template<
class UserAcc,
class Fct>
276 using namespace patch;
278 u64 sum_cell_count = 0;
281 sycl::queue &q = shamsys::instance::get_compute_queue();
283 u32 old_obj_cnt = pdat.get_obj_cnt();
287 if (refine_flags.count > 0) {
// Each refined cell keeps its slot and adds (split_count - 1) new ones.
289 pdat.expand(refine_flags.count * (split_count - 1));
296 = pdat.get_field<Tcoord>(0).get_buf().get_write_access(depends_list);
298 = pdat.get_field<Tcoord>(1).get_buf().get_write_access(depends_list);
301 UserAcc uacc(depends_list, pdat);
303 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
304 sycl::accessor index_to_ref{*refine_flags.idx, cgh, sycl::read_only};
// New cells are appended after the pre-expansion object count.
306 u32 start_index_push = old_obj_cnt;
308 constexpr u32 new_splits = split_count - 1;
311 sycl::range<1>(refine_flags.count), [=](sycl::item<1> gid) {
312 u32 tid = gid.get_linear_id();
314 u32 idx_to_refine = index_to_ref[gid];
318 cell_bound_low[idx_to_refine], cell_bound_high[idx_to_refine]};
// Geometric split of the parent cell into split_count children.
320 std::array<CellCoord, split_count> cell_coords
321 = CellCoord::get_split(cur_cell.bmin, cur_cell.bmax);
// Child 0 reuses the parent's slot; the rest go to the appended range.
324 std::array<u32, split_count> cells_ids;
325 cells_ids[0] = idx_to_refine;
328 for (
u32 pid = 0; pid < new_splits; pid++) {
329 cells_ids[pid + 1] = start_index_push + tid * new_splits + pid;
// Write the children's bounds into the coordinate fields.
335 for (
u32 pid = 0; pid < split_count; pid++) {
336 cell_bound_low[cells_ids[pid]] = cell_coords[pid].bmin;
337 cell_bound_high[cells_ids[pid]] = cell_coords[pid].bmax;
// Hook for the caller to propagate its own fields to the children.
341 lambd(idx_to_refine, cur_cell, cells_ids, cell_coords, uacc);
345 pdat.get_field<Tcoord>(0).get_buf().complete_event_state(e);
346 pdat.get_field<Tcoord>(1).get_buf().complete_event_state(e);
349 uacc.finalize(resulting_events, pdat);
352 sum_cell_count += pdat.get_obj_cnt();
355 logger::info_ln(
"AMRGrid",
"process cell count =", sum_cell_count);
// Fragment of the merge-application routine: for each flagged index, the
// split_count consecutive cells starting there are collapsed into one — the
// first slot receives the merged bounds, the remaining slots are marked
// dead in `keep_cell_flag` (presumably compacted away afterwards; that code
// is outside this excerpt). Signature and several interior lines missing.
358 template<
class UserAcc,
class Fct>
361 using namespace patch;
366 u32 old_obj_cnt = pdat.get_obj_cnt();
368 OptIndexList &derefine_flags = splts.get(id_patch);
370 if (derefine_flags.count > 0) {
// Initialise keep_cell_flag (the generator call is partially cut from
// this excerpt) — every cell presumably starts as "keep".
374 q.
q, old_obj_cnt, [](
u32 i) ->
u32 {
380 = pdat.get_field<Tcoord>(0).get_buf().get_write_access(depends_list);
382 = pdat.get_field<Tcoord>(1).get_buf().get_write_access(depends_list);
385 UserAcc uacc(depends_list, pdat);
388 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
389 sycl::accessor index_to_deref{*derefine_flags.idx, cgh, sycl::read_only};
391 sycl::accessor flag_keep{keep_cell_flag, cgh, sycl::read_write};
394 sycl::range<1>(derefine_flags.count), [=](sycl::item<1> gid) {
395 u32 tid = gid.get_linear_id();
397 u32 idx_to_derefine = index_to_deref[gid];
// The split_count siblings are contiguous starting at the flagged index
// (Morton ordering established by the merge-list generator).
400 std::array<u32, split_count> old_indexes;
402 for (u32 pid = 0; pid < split_count; pid++) {
403 old_indexes[pid] = idx_to_derefine + pid;
407 std::array<CellCoord, split_count> cell_coords;
409 for (
u32 pid = 0; pid < split_count; pid++) {
410 cell_coords[pid] = CellCoord{
411 cell_bound_low[old_indexes[pid]],
412 cell_bound_high[old_indexes[pid]]};
// Merged bounds overwrite the first sibling's slot.
416 CellCoord merged_cell_coord = CellCoord::get_merge(cell_coords);
419 cell_bound_low[idx_to_derefine] = merged_cell_coord.bmin;
420 cell_bound_high[idx_to_derefine] = merged_cell_coord.bmax;
// Remaining siblings are flagged for removal.
424 for (
u32 pid = 1; pid < split_count; pid++) {
425 flag_keep[idx_to_derefine + pid] = 0;
438 pdat.get_field<Tcoord>(0).get_buf().complete_event_state(e);
439 pdat.get_field<Tcoord>(1).get_buf().complete_event_state(e);
441 uacc.finalize(resulting_events, pdat);
451 "derefine cell count ",
// Reached when the compacted index buffer is unexpectedly empty.
457 throw std::runtime_error(
"opt buf must contain something at this point");
// Build the initial (unrefined) AMR grid: derive the simulation-box upper
// bound from bmin + cell_size * cell_count, size the patch decomposition via
// the GCD of the per-axis cell counts (clamped to a power of two), then fill
// per-cell min/max coordinate buffers on the host and install them as fields
// 0 and 1 of the patch data. Only implemented for dim == 3.
// NOTE(review): several interior lines are missing from this excerpt (e.g.
// the declaration preceding the bmax initializer at original lines 467-468).
466 inline void make_base_grid(Tcoord bmin, Tcoord cell_size, std::array<u32, dim> cell_count) {
469 bmin.x() + cell_size.x() * (cell_count[0]),
470 bmin.y() + cell_size.y() * (cell_count[1]),
471 bmin.z() + cell_size.z() * (cell_count[2])};
// Non-cubic cells are allowed but unusual — warn the user.
475 if ((cell_size.x() != cell_size.y()) || (cell_size.y() != cell_size.z())) {
476 logger::warn_ln(
"AMR Grid",
"your cells aren't cube");
479 static_assert(dim == 3,
"this is not implemented for dim != 3");
481 std::array<u32, dim> patch_count;
// Largest power of two representable in u32; clamps the GCD below so the
// resulting patch grid dimensions stay powers-of-two friendly.
483 constexpr u32 gcd_pow2 = 1U << 31U;
486 gcd_cell_count = std::gcd(cell_count[0], cell_count[1]);
487 gcd_cell_count = std::gcd(gcd_cell_count, cell_count[2]);
488 gcd_cell_count = std::gcd(gcd_cell_count, gcd_pow2);
494 cell_count[0] / gcd_cell_count,
495 cell_count[1] / gcd_cell_count,
496 cell_count[2] / gcd_cell_count);
// One base patch per gcd-sized block of cells along each axis.
498 sched.make_patch_base_grid<3>(
499 {{cell_count[0] / gcd_cell_count,
500 cell_count[1] / gcd_cell_count,
501 cell_count[2] / gcd_cell_count}});
507 u32 cell_tot_count = cell_count[0] * cell_count[1] * cell_count[2];
509 sycl::buffer<Tcoord> cell_coord_min(cell_tot_count);
510 sycl::buffer<Tcoord> cell_coord_max(cell_tot_count);
512 shamlog_debug_sycl_ln(
513 "AMRGrid",
"building bounds ", cell_count[0], cell_count[1], cell_count[2]);
// Host-side fill of the per-cell bounds (write_only + no_init: every slot
// is assigned in the triple loop below).
516 sycl::host_accessor acc_min{cell_coord_min, sycl::write_only, sycl::no_init};
517 sycl::host_accessor acc_max{cell_coord_max, sycl::write_only, sycl::no_init};
519 sycl::range<3> rnge{cell_count[0], cell_count[1], cell_count[2]};
521 u32 cnt_x = cell_count[0];
522 u32 cnt_y = cell_count[1];
523 u32 cnt_z = cell_count[2];
525 u32 cnt_xy = cnt_x * cnt_y;
527 Tcoord sz = cell_size;
529 for (
u64 idx = 0; idx < cell_count[0]; idx++) {
530 for (
u64 idy = 0; idy < cell_count[1]; idy++) {
531 for (
u64 idz = 0; idz < cell_count[2]; idz++) {
// Linear index: x fastest, then y, then z.
533 u64 id_a = idx + cnt_x * idy + cnt_xy * idz;
// NOTE(review): bmin is NOT added here — cell bounds start at the
// origin, in units of cell_size. Presumably the grid coordinate
// space is anchored at zero and bmin only defines the sim box;
// confirm against the coordinate-mapping code.
535 acc_min[id_a] = sz * Tcoord{idx, idy, idz};
536 acc_max[id_a] = sz * Tcoord{idx + 1, idy + 1, idz + 1};
542 shambase::check_queue_state(shamsys::instance::get_compute_queue());
// Install the generated bounds as the two mandatory AMR fields
// (0 = cell_min, 1 = cell_max; see check_amr_main_fields).
545 pdat.resize(cell_tot_count);
547 shambase::check_queue_state(shamsys::instance::get_compute_queue());
548 pdat.get_field<Tcoord>(0).
override(cell_coord_min, cell_tot_count);
550 shambase::check_queue_state(shamsys::instance::get_compute_queue());
551 pdat.get_field<Tcoord>(1).
override(cell_coord_max, cell_tot_count);
553 shambase::check_queue_state(shamsys::instance::get_compute_queue());
557 shambase::check_queue_state(shamsys::instance::get_compute_queue());
559 shamlog_debug_sycl_ln(
"AMRGrid",
"grid init done");