// LeapfrogGeneral: leapfrog (predictor / corrector) time-step driver whose
// debug output is tagged "SPHLeapfrog".
// NOTE(review): this file is a lossy extract (original line numbers are embedded
// in each line and many lines, including closing braces, are missing).
54 class LeapfrogGeneral {
// 3-component SYCL vector of the floating point type; used for the xyz, vxyz,
// axyz and axyz_old fields.
56 using vec3 = sycl::vec<flt, 3>;
// Compile-time guard: only f16, f32 and f64 are accepted as the floating point type.
59 std::is_same<flt, f16>::value || std::is_same<flt, f32>::value
60 || std::is_same<flt, f64>::value,
61 "Leapfrog : floating point type should be one of (f16,f32,f64)");
// Drift step: for every particle index i in [0, npart), submits a SYCL kernel
// computing xyz[i] <- xyz[i] + dt * vxyz[i].
// `queue`, `npart` and `dt` are presumably parameters declared on lines missing
// from this excerpt -- TODO confirm.
// NOTE(review): in the visible kernel body `gid` is unused and acc_vxyz is only
// read; if that holds for the full body, a read-only accessor for vxyz would
// avoid an unnecessary write dependency on that buffer.
63 inline static void sycl_move_parts(
67 const std::unique_ptr<sycl::buffer<vec3>> &buf_xyz,
68 const std::unique_ptr<sycl::buffer<vec3>> &buf_vxyz) {
70 sycl::range<1> range_npart{npart};
72 auto ker_predict_step = [&](sycl::handler &cgh) {
74 = buf_xyz->template get_access<sycl::access::mode::read_write>(cgh);
76 = buf_vxyz->template get_access<sycl::access::mode::read_write>(cgh);
79 cgh.parallel_for(range_npart, [=](sycl::item<1> item) {
80 u32 gid = (
u32) item.get_id();
// Velocity of this particle, bound by reference to the accessor element.
82 vec3 &vxyz = acc_vxyz[item];
84 acc_xyz[item] = acc_xyz[item] + dt * vxyz;
88 queue.submit(ker_predict_step);
// Periodic wrap of particle positions: for each particle, takes its offset from
// box_min and reduces it modulo the box extent delt = box_max - box_min.
// The write-back of r into xyz is on lines missing from this excerpt -- TODO confirm.
// NOTE(review): sycl::fmod (like std::fmod) returns a result with the sign of the
// dividend, so a negative offset stays negative after one fmod. Applying fmod twice
// only changes anything if r is adjusted in between (presumably `r += delt` on a
// missing line) -- verify, otherwise particles left of box_min are not wrapped.
// NOTE(review): the command-group lambda is named ker_predict_step although it
// performs the modulo, not a prediction.
91 inline static void sycl_position_modulo(
94 const std::unique_ptr<sycl::buffer<vec3>> &buf_xyz,
95 std::tuple<vec3, vec3> box) {
97 sycl::range<1> range_npart{npart};
99 auto ker_predict_step = [&](sycl::handler &cgh) {
100 auto xyz = buf_xyz->template get_access<sycl::access::mode::read_write>(cgh);
// Box bounds and extent are computed on the host and copied into the kernel.
102 vec3 box_min = std::get<0>(box);
103 vec3 box_max = std::get<1>(box);
104 vec3 delt = box_max - box_min;
107 cgh.parallel_for(range_npart, [=](sycl::item<1> item) {
108 u32 gid = (
u32) item.get_id();
110 vec3 r = xyz[gid] - box_min;
112 r = sycl::fmod(r, delt);
114 r = sycl::fmod(r, delt);
121 queue.submit(ker_predict_step);
// Constructor member-initializer list (the signature is missing from this
// excerpt): each member is initialized from the argument of the same name.
138 : sched(sched), periodic_mode(periodic_mode), htol_up_tol(htol_up_tol),
139 htol_up_iter(htol_up_iter), sph_gpart_mass(sph_gpart_mass) {}
// One leapfrog step over every patch of `sched`.
// NOTE(review): this excerpt is lossy -- the function name, the leading template
// parameters and many body lines are missing, so only the visible phases are
// described here. Visible phases, in order:
//   1. per patch: "predictor", "dt fields swap", then sycl_position_modulo on xyz
//   2. particle reattribution across patches (reatribute_particles)
//   3. per-patch h_max field, interface list construction and communication
//   4. merge of interfaces with patch data; per merged patch: radix tree build,
//      cell bounding boxes, interaction boxes
//   5. smoothing-length iteration producing hnew / omega, hnew copied into hpart
//   6. exchange of hnew / omega on interfaces and merge into compute fields
//   7. user callbacks lambda_post_sync and lambda_compute_forces
//   8. per-patch "leapfrog corrector", then write_back_merge_patches
143 class LambdaUpdateTime,
145 class LambdaPostSync,
147 class LambdaCorrector>
152 LambdaCFL &&lambda_cfl,
153 LambdaUpdateTime &&lambda_update_time,
154 LambdaSwapDer &&lambda_swap_der,
155 LambdaPostSync &&lambda_post_sync,
156 LambdaForce &&lambda_compute_forces,
157 LambdaCorrector &&lambda_correct) {
159 using namespace shamrock::patch;
// Local copies of the tolerance parameters so they can be captured by value
// (loc_htol_up_tol is captured by the h_max lambda further down).
// NOTE(review): loc_htol_up_iter is declared flt and is not referenced in the
// visible lines (htol_up_iter is passed directly to the h iterator). If
// htol_up_iter is an integer iteration count, confirm this conversion is intended.
161 const flt loc_htol_up_tol = htol_up_tol;
162 const flt loc_htol_up_iter = htol_up_iter;
// Resolve field indices in the scheduler's patch data layout once, by name.
164 const u32 ixyz = sched.pdl.get_field_idx<vec3>(
"xyz");
165 const u32 ivxyz = sched.pdl.get_field_idx<vec3>(
"vxyz");
166 const u32 iaxyz = sched.pdl.get_field_idx<vec3>(
"axyz");
167 const u32 iaxyz_old = sched.pdl.get_field_idx<vec3>(
"axyz_old");
169 const u32 ihpart = sched.pdl.get_field_idx<flt>(
"hpart");
// NOTE(review): tail of a statement whose head is missing from this excerpt;
// presumably constructs the patch tree `sptree` used below -- TODO confirm.
182 sched.patch_tree, sched.get_sim_box().template get_patch_transform<vec3>());
// get_cfl: evaluates lambda_cfl on every patch, reduces the per-patch values
// (reduce_val) and returns the reduced value.
186 auto get_cfl = [&]() -> flt {
188 cfl_glb_var.compute_var_patch(sched, lambda_cfl);
189 cfl_glb_var.reduce_val();
190 return cfl_glb_var.get_val();
193 flt cfl_val = get_cfl();
// The reduced CFL value is used directly as the current time step.
197 flt dt_cur = cfl_val;
// old_time is presumably a parameter declared on a line missing from this excerpt.
202 flt step_time = old_time;
// Phase 1 (per patch): predictor, "dt fields swap" and periodic position wrap
// of xyz into the box returned by get_box_volume.
206 sched.for_each_patch_data([&](
u64 id_patch,
Patch cur_p, PatchData &pdat) {
207 shamlog_debug_ln(
"SPHLeapfrog",
"patch : n", id_patch,
"->",
"predictor");
212 sycl::range<1>{pdat.get_obj_cnt()},
219 pdat.get_field<vec3>(ixyz).get_buf(),
220 pdat.get_field<vec3>(ivxyz).get_buf());
225 sycl::range<1>{pdat.get_obj_cnt()},
228 shamlog_debug_ln(
"SPHLeapfrog",
"patch : n", id_patch,
"->",
"dt fields swap");
233 sycl::range<1>{pdat.get_obj_cnt()});
236 sycl_position_modulo(
239 pdat.get_field<vec3>(ixyz).get_buf(),
240 sched.get_box_volume<vec3>());
// Phase 2: particle reattribution across patches using sptree.
245 shamlog_debug_ln(
"SPHLeapfrog",
"particle reatribution");
246 reatribute_particles(sched, sptree, periodic_mode);
// Phase 3: per-patch value get_h_max * htol_up_tol * Kernel::Rkern, presumably
// stored in h_field, which drives the interface selection below.
248 shamlog_debug_ln(
"SPHLeapfrog",
"compute hmax of each patches");
259 sched.compute_patch_field(
262 [loc_htol_up_tol](sycl::queue &queue,
Patch &p, PatchData &pdat) {
263 return patchdata::sph::get_h_max<flt>(pdat.pdl, queue, pdat)
264 * loc_htol_up_tol * Kernel::Rkern;
267 shamlog_debug_ln(
"SPHLeapfrog",
"compute interface list");
// Build the SPH interface list from sptree and h_field, then communicate it.
269 LegacyInterfacehandler<vec3, flt> interface_hndl;
270 interface_hndl.template compute_interface_list<InterfaceSelector_SPH<vec3, flt>>(
271 sched, sptree, h_field, periodic_mode);
273 shamlog_debug_ln(
"SPHLeapfrog",
"communicate interfaces");
274 interface_hndl.comm_interfaces(sched, periodic_mode);
// Phase 4: merge communicated interfaces with patch data. merge_pdat maps
// patch id -> MergedPatchData; or_element_cnt appears to be the element count
// before merging (it is logged as the original size next to get_obj_cnt() in
// the h iteration loop). Patches with or_element_cnt == 0 are skipped below.
276 shamlog_debug_ln(
"SPHLeapfrog",
"merging interfaces with data");
283 std::unordered_map<u64, MergedPatchData<flt>> merge_pdat
// NOTE(review): reduc_level is not used in the visible lines; presumably passed
// to the RadixTree constructor -- TODO confirm.
288 constexpr u32 reduc_level = 5;
291 std::unordered_map<u64, std::unique_ptr<RadixTree<u_morton, vec3>>> radix_trees;
292 std::unordered_map<u64, std::unique_ptr<RadixTreeField<flt>>> cell_int_rads;
// Build one radix tree per merged patch from its xyz buffer and bounding box.
294 sched.for_each_patch([&](
u64 id_patch,
const Patch & ) {
296 "SPHLeapfrog",
"patch : n", id_patch,
"->",
"making Radix Tree");
298 if (merge_pdat.at(id_patch).or_element_cnt == 0)
304 "is empty skipping tree build");
307 PatchData &mpdat = merge_pdat.at(id_patch).data;
309 auto &buf_xyz = mpdat.get_field<vec3>(ixyz).get_buf();
311 std::tuple<vec3, vec3> &box = merge_pdat.at(id_patch).box;
314 radix_trees[id_patch] = std::make_unique<RadixTree<u_morton, vec3>>(
// Compute and convert the cell bounding boxes of each radix tree.
322 sched.for_each_patch([&](
u64 id_patch,
Patch ) {
328 "compute radix tree cell volumes");
329 if (merge_pdat.at(id_patch).or_element_cnt == 0)
335 "is empty skipping tree volumes step");
337 radix_trees[id_patch]->compute_cell_ibounding_box(
339 radix_trees[id_patch]->convert_bounding_box(
// Compute per-cell interaction boxes from the hpart field; kept in cell_int_rads.
343 sched.for_each_patch([&](
u64 id_patch,
Patch ) {
349 "compute Radix Tree interaction boxes");
350 if (merge_pdat.at(id_patch).or_element_cnt == 0)
356 "is empty skipping interaction box compute");
358 PatchData &mpdat = merge_pdat.at(id_patch).data;
360 auto &buf_h = mpdat.get_field<flt>(ihpart).get_buf();
362 cell_int_rads[id_patch] = std::make_unique<RadixTreeField<flt>>(
363 radix_trees[id_patch]->compute_int_boxes(
// Phase 5: smoothing-length iteration producing hnew and omega per patch.
369 shamlog_debug_ln(
"SPHLeapfrog",
"init compute fields : hnew, omega");
374 hnew_field.generate(sched);
375 omega_field.generate(sched);
378 sched.for_each_patch([&](
u64 id_patch,
Patch ) {
380 "SPHLeapfrog",
"patch : n", id_patch,
"->",
"Init h iteration");
381 if (merge_pdat.at(id_patch).or_element_cnt == 0)
387 "is empty skipping h iteration");
// eps_h and range_npart cover only the first or_element_cnt elements.
389 PatchData &pdat_merge = merge_pdat.at(id_patch).data;
391 auto &hnew = hnew_field.get_buf(id_patch);
392 auto &omega = omega_field.get_buf(id_patch);
393 sycl::buffer<flt> eps_h
394 = sycl::buffer<flt>(merge_pdat.at(id_patch).or_element_cnt);
396 sycl::range range_npart{merge_pdat.at(id_patch).or_element_cnt};
400 "merging -> original size :",
401 merge_pdat.at(id_patch).or_element_cnt,
403 pdat_merge.get_obj_cnt());
406 sched.pdl, htol_up_tol, htol_up_iter);
408 h_iterator.iterate_smoothing_length(
410 merge_pdat.at(id_patch).or_element_cnt,
412 *radix_trees[id_patch],
413 *cell_int_rads[id_patch],
// Copy hnew into the hpart field for the first or_element_cnt entries.
422 auto h_new = hnew->template get_access<sycl::access::mode::read>(cgh);
424 auto acc_hpart = pdat_merge.get_field<flt>(ihpart)
426 ->template get_access<sycl::access::mode::write>(cgh);
428 cgh.parallel_for(range_npart, [=](sycl::item<1> item) {
429 acc_hpart[item] = h_new[item];
// Phase 6: exchange hnew and omega across interfaces (results presumably bound
// to hnew_field_interfaces / omega_field_interfaces) and merge them into
// per-patch compute fields.
435 shamlog_debug_ln(
"SPHLeapfrog",
"exchange interface hnew");
437 = interface_hndl.template comm_interfaces_field<flt>(
438 sched, hnew_field, periodic_mode);
439 shamlog_debug_ln(
"SPHLeapfrog",
"exchange interface omega");
441 = interface_hndl.template comm_interfaces_field<flt>(
442 sched, omega_field, periodic_mode);
452 std::unordered_map<u64, MergedPatchCompField<flt, flt>> hnew_field_merged
453 = MergedPatchCompField<flt, flt>::merge_patches_cfield(
454 sched, interface_hndl, hnew_field, hnew_field_interfaces);
456 std::unordered_map<u64, MergedPatchCompField<flt, flt>> omega_field_merged
457 = MergedPatchCompField<flt, flt>::merge_patches_cfield(
458 sched, interface_hndl, omega_field, omega_field_interfaces);
// Phase 7: user callbacks on the merged data and merged hnew / omega fields.
462 lambda_post_sync(sched, merge_pdat, hnew_field_merged, omega_field_merged);
466 lambda_compute_forces(
// Phase 8: corrector on each merged patch over its first or_element_cnt elements.
477 sched.for_each_patch([&](
u64 id_patch,
Patch ) {
// NOTE(review): unlike the other phases, this skip path prints to std::cout with
// std::endl (flushing on every call) instead of using shamlog_debug_ln.
478 if (merge_pdat.at(id_patch).or_element_cnt == 0) {
479 std::cout <<
" empty => skipping" << std::endl;
483 PatchData &pdat_merge = merge_pdat.at(id_patch).data;
487 shamlog_debug_ln(
"SPHLeapfrog",
"leapfrog corrector");
492 sycl::range<1>{merge_pdat.at(id_patch).or_element_cnt},
// Write the merged patch data back through write_back_merge_patches.
498 write_back_merge_patches(sched, merge_pdat);