23template<
class Tvec,
class Tgr
idVec>
24template<
class UserAcc,
class... T>
31 using namespace shamrock::patch;
37 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
43 u32 obj_cnt = pdat.get_obj_cnt();
51 UserAcc uacc(depends_list, id_patch, cur_p, pdat,
args...);
53 auto refine_acc = refine_flags.get_write_access(depends_list);
54 auto derefine_acc = derefine_flags.get_write_access(depends_list);
57 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
58 cgh.parallel_for(sycl::range<1>(obj_cnt), [=](sycl::item<1> gid) {
59 bool flag_refine =
false;
60 bool flag_derefine =
false;
61 uacc.refine_criterion(gid.get_linear_id(), uacc, flag_refine, flag_derefine);
64 if (flag_refine && flag_derefine) {
65 flag_derefine =
false;
68 refine_acc[gid] = (flag_refine) ? 1 : 0;
69 derefine_acc[gid] = (flag_derefine) ? 1 : 0;
76 refine_flags.complete_event_state(resulting_events);
77 derefine_flags.complete_event_state(resulting_events);
79 uacc.finalize(resulting_events, id_patch, cur_p, pdat,
args...);
87 auto acc_max = buf_cell_max.get_read_access(depends_list);
88 auto acc_merge_flag = derefine_flags.get_write_access(depends_list);
91 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
92 cgh.parallel_for(sycl::range<1>(obj_cnt), [=](sycl::item<1> gid) {
93 u32 id = gid.get_linear_id();
95 std::array<BlockCoord, split_count> blocks;
100 if (
id + split_count <= obj_cnt) {
101 bool all_want_to_merge =
true;
103 for (
u32 lid = 0; lid < split_count; lid++) {
104 blocks[lid] = BlockCoord{acc_min[gid + lid], acc_max[gid + lid]};
105 all_want_to_merge = all_want_to_merge && acc_merge_flag[gid + lid];
108 do_merge = all_want_to_merge && BlockCoord::are_mergeable(blocks);
114 acc_merge_flag[gid] = do_merge;
119 buf_cell_max.complete_event_state(e);
120 derefine_flags.complete_event_state(e);
130 "AMRGrid",
"patch ", id_patch,
"refine block count = ", buf_refine.get_size());
132 tot_refine += buf_refine.get_size();
135 dd_refine_list.add_obj(id_patch, std::move(buf_refine));
145 "AMRGrid",
"patch ", id_patch,
"merge block count = ", buf_derefine.get_size());
147 tot_derefine += buf_derefine.get_size();
150 dd_derefine_list.add_obj(id_patch, std::move(buf_derefine));
153 logger::info_ln(
"AMRGrid",
"on this process", tot_refine,
"blocks were refined");
155 "AMRGrid",
"on this process", tot_derefine * split_count,
"blocks were derefined");
157template<
class Tvec,
class Tgr
idVec>
158template<
class UserAcc>
162 using namespace shamrock::patch;
164 u64 sum_block_count = 0;
166 bool new_cell_were_added =
false;
171 u32 old_obj_cnt = pdat.get_obj_cnt();
178 pdat.expand(refine_flags.
get_size() * (split_count - 1));
185 auto block_bound_high = buf_cell_max.get_write_access(depends_list);
186 UserAcc uacc(depends_list, pdat);
190 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
191 u32 start_index_push = old_obj_cnt;
193 constexpr u32 new_splits = split_count - 1;
195 cgh.parallel_for(sycl::range<1>(refine_flags.
get_size()), [=](sycl::item<1> gid) {
196 u32 tid = gid.get_linear_id();
198 u32 idx_to_refine = index_to_ref[tid];
201 BlockCoord cur_block{
202 block_bound_low[idx_to_refine], block_bound_high[idx_to_refine]};
204 std::array<BlockCoord, split_count> block_coords
205 = BlockCoord::get_split(cur_block.bmin, cur_block.bmax);
208 std::array<u32, split_count> blocks_ids;
209 blocks_ids[0] = idx_to_refine;
214 for (
u32 pid = 0; pid < new_splits; pid++) {
215 blocks_ids[pid + 1] = start_index_push + tid * new_splits + pid;
221 for (
u32 pid = 0; pid < split_count; pid++) {
222 block_bound_low[blocks_ids[pid]] = block_coords[pid].bmin;
223 block_bound_high[blocks_ids[pid]] = block_coords[pid].bmax;
227 uacc.apply_refine(idx_to_refine, cur_block, blocks_ids, block_coords, uacc);
234 buf_cell_max.complete_event_state(resulting_events);
236 uacc.finalize(resulting_events, pdat);
241 sum_block_count += pdat.get_obj_cnt();
242 new_cell_were_added = new_cell_were_added || refine_flags.
get_size() > 0;
245 logger::info_ln(
"AMRGrid",
"process block count =", sum_block_count);
247 return new_cell_were_added;
250template<
class Tvec,
class Tgr
idVec>
251template<
class UserAcc>
255 using namespace shamrock::patch;
257 bool cell_were_removed =
false;
260 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
263 u32 old_obj_cnt = pdat.get_obj_cnt();
267 if (derefine_flags.
get_size() > 0) {
270 sham::DeviceBuffer<u32> keep_block_flag(old_obj_cnt, dev_sched);
271 keep_block_flag.fill(1);
273 sham::DeviceBuffer<TgridVec> &buf_cell_min = pdat.get_field_buf_ref<TgridVec>(0);
274 sham::DeviceBuffer<TgridVec> &buf_cell_max = pdat.get_field_buf_ref<TgridVec>(1);
276 sham::EventList depends_list;
277 auto block_bound_low = buf_cell_min.get_write_access(depends_list);
278 auto block_bound_high = buf_cell_max.get_write_access(depends_list);
279 UserAcc uacc(depends_list, pdat);
280 auto index_to_deref = derefine_flags.get_read_access(depends_list);
281 auto flag_keep = keep_block_flag.get_write_access(depends_list);
284 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
285 cgh.parallel_for(sycl::range<1>(derefine_flags.get_size()), [=](sycl::item<1> gid) {
286 u32 tid = gid.get_linear_id();
288 u32 idx_to_derefine = index_to_deref[gid];
291 std::array<u32, split_count> old_indexes;
293 for (u32 pid = 0; pid < split_count; pid++) {
294 old_indexes[pid] = idx_to_derefine + pid;
298 std::array<BlockCoord, split_count> block_coords;
300 for (u32 pid = 0; pid < split_count; pid++) {
301 block_coords[pid] = BlockCoord{
302 block_bound_low[old_indexes[pid]], block_bound_high[old_indexes[pid]]};
306 BlockCoord merged_block_coord = BlockCoord::get_merge(block_coords);
309 block_bound_low[idx_to_derefine] = merged_block_coord.bmin;
310 block_bound_high[idx_to_derefine] = merged_block_coord.bmax;
314 for (u32 pid = 1; pid < split_count; pid++) {
315 flag_keep[idx_to_derefine + pid] = 0;
320 old_indexes, block_coords, idx_to_derefine, merged_block_coord, uacc);
327 buf_cell_max.complete_event_state(resulting_events);
329 uacc.finalize(resulting_events, pdat);
331 keep_block_flag.complete_event_state(resulting_events);
342 "derefine block count ",
345 buf_keep.get_size());
347 if (buf_keep.get_size() == 0) {
348 throw std::runtime_error(
"buf keep must contain something at this point");
354 cell_were_removed = cell_were_removed || derefine_flags.
get_size() > 0;
358 return cell_were_removed;
361template<
class Tvec,
class Tgr
idVec>
362template<
class UserAccCrit,
class UserAccSplit,
class UserAccMerge>
367 AMRSortBlocks block_sorter(context, solver_config, storage);
368 block_sorter.reorder_amr_blocks();
374 gen_refine_block_changes<UserAccCrit>(dd_refine_list, dd_derefine_list);
378 internal_refine_grid<UserAccSplit>(std::move(dd_refine_list));
385 internal_derefine_grid<UserAccMerge>(std::move(dd_derefine_list));
388template<
class Tvec,
class Tgr
idVec>
392 class RefineCritBlock {
394 const TgridVec *block_low_bound;
395 const TgridVec *block_high_bound;
396 const Tscal *block_density_field;
398 Tscal one_over_Nside = 1. / AMRBlock::Nside;
410 : dxfact(dxfact), wanted_mass(wanted_mass) {
412 block_low_bound = pdat.get_field<TgridVec>(0).get_buf().get_read_access(depends_list);
413 block_high_bound = pdat.get_field<TgridVec>(1).get_buf().get_read_access(depends_list);
414 block_density_field = pdat.get_field<Tscal>(pdat.pdl().
get_field_idx<Tscal>(
"rho"))
416 .get_read_access(depends_list);
432 pdat.get_field<Tscal>(pdat.pdl().
get_field_idx<Tscal>(
"rho"))
434 .complete_event_state(resulting_events);
437 void refine_criterion(
438 u32 block_id, RefineCritBlock acc,
bool &should_refine,
bool &should_derefine)
const {
440 TgridVec low_bound = acc.block_low_bound[block_id];
441 TgridVec high_bound = acc.block_high_bound[block_id];
443 Tvec lower_flt = low_bound.template convert<Tscal>() * dxfact;
444 Tvec upper_flt = high_bound.template convert<Tscal>() * dxfact;
446 Tvec block_cell_size = (upper_flt - lower_flt) * one_over_Nside;
449 for (
u32 i = 0; i < AMRBlock::block_size; i++) {
450 sum_mass += acc.block_density_field[i + block_id * AMRBlock::block_size];
452 sum_mass *= block_cell_size.x() * block_cell_size.y() * block_cell_size.z();
454 if (sum_mass > wanted_mass * 8) {
455 should_refine =
true;
456 should_derefine =
false;
457 }
else if (sum_mass < wanted_mass) {
458 should_refine =
false;
459 should_derefine =
true;
461 should_refine =
false;
462 should_derefine =
false;
465 should_refine = should_refine && (high_bound.x() - low_bound.x() > AMRBlock::Nside);
466 should_refine = should_refine && (high_bound.y() - low_bound.y() > AMRBlock::Nside);
467 should_refine = should_refine && (high_bound.z() - low_bound.z() > AMRBlock::Nside);
471 class RefineCellAccessor {
479 rho = pdat.get_field<
f64>(2).get_buf().get_write_access(depends_list);
480 rho_vel = pdat.get_field<f64_3>(3).get_buf().get_write_access(depends_list);
481 rhoE = pdat.get_field<
f64>(4).get_buf().get_write_access(depends_list);
485 pdat.get_field<
f64>(2).get_buf().complete_event_state(resulting_events);
486 pdat.get_field<f64_3>(3).get_buf().complete_event_state(resulting_events);
487 pdat.get_field<
f64>(4).get_buf().complete_event_state(resulting_events);
492 BlockCoord cur_coords,
493 std::array<u32, 8> new_blocks,
494 std::array<BlockCoord, 8> new_block_coords,
495 RefineCellAccessor acc)
const {
497 auto get_coord_ref = [](
u32 i) -> std::array<u32, dim> {
498 constexpr u32 NsideBlockPow = 1;
499 constexpr u32 Nside = 1U << NsideBlockPow;
501 if constexpr (dim == 3) {
502 const u32 tmp = i >> NsideBlockPow;
503 return {i % Nside, (tmp) % Nside, (tmp) >> NsideBlockPow};
507 auto get_index_block = [](std::array<u32, dim> coord) ->
u32 {
508 constexpr u32 NsideBlockPow = 1;
509 constexpr u32 Nside = 1U << NsideBlockPow;
511 if constexpr (dim == 3) {
512 return coord[0] + Nside * coord[1] + Nside * Nside * coord[2];
516 auto get_gid_write = [&](std::array<u32, dim> &glid) ->
u32 {
517 std::array<u32, dim> bid
518 = {glid[0] >> AMRBlock::NsideBlockPow,
519 glid[1] >> AMRBlock::NsideBlockPow,
520 glid[2] >> AMRBlock::NsideBlockPow};
523 return new_blocks[get_index_block(bid)] * AMRBlock::block_size
524 + AMRBlock::get_index(
525 {glid[0] % AMRBlock::Nside,
526 glid[1] % AMRBlock::Nside,
527 glid[2] % AMRBlock::Nside});
530 std::array<f64, AMRBlock::block_size> old_rho_block;
531 std::array<f64_3, AMRBlock::block_size> old_rho_vel_block;
532 std::array<f64, AMRBlock::block_size> old_rhoE_block;
535 for (
u32 loc_id = 0; loc_id < AMRBlock::block_size; loc_id++) {
537 auto [lx, ly, lz] = get_coord_ref(loc_id);
538 u32 old_cell_idx = cur_idx * AMRBlock::block_size + loc_id;
539 old_rho_block[loc_id] = acc.rho[old_cell_idx];
540 old_rho_vel_block[loc_id] = acc.rho_vel[old_cell_idx];
541 old_rhoE_block[loc_id] = acc.rhoE[old_cell_idx];
544 for (
u32 loc_id = 0; loc_id < AMRBlock::block_size; loc_id++) {
546 auto [lx, ly, lz] = get_coord_ref(loc_id);
547 u32 old_cell_idx = cur_idx * AMRBlock::block_size + loc_id;
549 Tscal rho_block = old_rho_block[loc_id];
550 Tvec rho_vel_block = old_rho_vel_block[loc_id];
551 Tscal rhoE_block = old_rhoE_block[loc_id];
552 for (
u32 subdiv_lid = 0; subdiv_lid < 8; subdiv_lid++) {
554 auto [sx, sy, sz] = get_coord_ref(subdiv_lid);
556 std::array<u32, 3> glid = {lx * 2 + sx, ly * 2 + sy, lz * 2 + sz};
558 u32 new_cell_idx = get_gid_write(glid);
573 acc.rho[new_cell_idx] = rho_block;
574 acc.rho_vel[new_cell_idx] = rho_vel_block;
575 acc.rhoE[new_cell_idx] = rhoE_block;
581 std::array<u32, 8> old_blocks,
582 std::array<BlockCoord, 8> old_coords,
584 BlockCoord new_coord,
586 RefineCellAccessor acc)
const {
588 std::array<f64, AMRBlock::block_size> rho_block;
589 std::array<f64_3, AMRBlock::block_size> rho_vel_block;
590 std::array<f64, AMRBlock::block_size> rhoE_block;
592 for (
u32 cell_id = 0; cell_id < AMRBlock::block_size; cell_id++) {
593 rho_block[cell_id] = {};
594 rho_vel_block[cell_id] = {};
595 rhoE_block[cell_id] = {};
598 for (
u32 pid = 0; pid < 8; pid++) {
599 for (
u32 cell_id = 0; cell_id < AMRBlock::block_size; cell_id++) {
600 rho_block[cell_id] += acc.rho[old_blocks[pid] * AMRBlock::block_size + cell_id];
601 rho_vel_block[cell_id]
602 += acc.rho_vel[old_blocks[pid] * AMRBlock::block_size + cell_id];
604 += acc.rhoE[old_blocks[pid] * AMRBlock::block_size + cell_id];
608 for (
u32 cell_id = 0; cell_id < AMRBlock::block_size; cell_id++) {
609 rho_block[cell_id] /= 8;
610 rho_vel_block[cell_id] /= 8;
611 rhoE_block[cell_id] /= 8;
614 for (
u32 cell_id = 0; cell_id < AMRBlock::block_size; cell_id++) {
615 u32 newcell_idx = new_cell * AMRBlock::block_size + cell_id;
616 acc.rho[newcell_idx] = rho_block[cell_id];
617 acc.rho_vel[newcell_idx] = rho_vel_block[cell_id];
618 acc.rhoE[newcell_idx] = rhoE_block[cell_id];
623 using AMRmode_None =
typename AMRMode<Tvec, TgridVec>::None;
624 using AMRmode_DensityBased =
typename AMRMode<Tvec, TgridVec>::DensityBased;
626 bool has_cell_order_changed =
false;
628 if (AMRmode_None *cfg = std::get_if<AMRmode_None>(&solver_config.amr_mode.config)) {
631 AMRmode_DensityBased *cfg
632 = std::get_if<AMRmode_DensityBased>(&solver_config.amr_mode.config)) {
633 Tscal dxfact(solver_config.grid_coord_to_pos_fact);
639 gen_refine_block_changes<RefineCritBlock>(
640 refine_list, derefine_list, dxfact, cfg->crit_mass);
644 bool change_refine = internal_refine_grid<RefineCellAccessor>(std::move(refine_list));
651 bool change_derefine = internal_derefine_grid<RefineCellAccessor>(std::move(derefine_list));
653 has_cell_order_changed = has_cell_order_changed || (change_refine || change_derefine);
656 if (has_cell_order_changed) {
658 AMRSortBlocks block_sorter(context, solver_config, storage);
659 block_sorter.reorder_amr_blocks();
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
size_t get_size() const
Gets the number of elements in the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
void add_event(sycl::event e)
Add an event to the list of events.
Represents a collection of objects distributed across patches identified by a u64 id.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
void index_remap_resize(sycl::buffer< u32 > &index_map, u32 len)
this function remaps the patchdatafield like so val[id] = val[index_map[id]] This function can be use...
main include file for the shamalgs algorithms
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
std::vector< std::string_view > args
Executable argument list (mapped from argv)
Patch object that contain generic patch information.
u64 id_patch
unique key that identify the patch