Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
leapfrog.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
18
19#include "shambackends/sycl.hpp"
20#include "shamrock/legacy/io/dump.hpp"
22#include "shamrock/legacy/patch/comm/patch_object_mover.hpp"
23#include "shamrock/legacy/patch/interfaces/interface_handler.hpp"
24#include "shamrock/legacy/patch/interfaces/interface_selector.hpp"
27#include "shamrock/legacy/patch/utility/merged_patch.hpp"
30// #include "shamrock/legacy/patch/patchdata_buffer.hpp"
31#include "shammodels/sph/legacy/algs/smoothing_length.hpp"
32#include "shammodels/sph/legacy/sphpatch.hpp"
33#include "shamrock/legacy/patch/base/patchdata_field.hpp"
36#include "shamrock/sph/kernels.hpp"
37#include "shamrock/sph/sphpart.hpp"
40#include <unordered_map>
41#include <filesystem>
42#include <memory>
43#include <string>
44#include <tuple>
45#include <vector>
46
47//%Impl status : Good
48
49namespace integrators {
50
51 namespace sph {
52
53 template<class flt, class Kernel, class u_morton>
54 class LeapfrogGeneral {
55 public:
56 using vec3 = sycl::vec<flt, 3>;
57
58 static_assert(
59 std::is_same<flt, f16>::value || std::is_same<flt, f32>::value
60 || std::is_same<flt, f64>::value,
61 "Leapfrog : floating point type should be one of (f16,f32,f64)");
62
63 inline static void sycl_move_parts(
64 sycl::queue &queue,
65 u32 npart,
66 flt dt,
67 const std::unique_ptr<sycl::buffer<vec3>> &buf_xyz,
68 const std::unique_ptr<sycl::buffer<vec3>> &buf_vxyz) {
69
70 sycl::range<1> range_npart{npart};
71
72 auto ker_predict_step = [&](sycl::handler &cgh) {
73 auto acc_xyz
74 = buf_xyz->template get_access<sycl::access::mode::read_write>(cgh);
75 auto acc_vxyz
76 = buf_vxyz->template get_access<sycl::access::mode::read_write>(cgh);
77
78 // Executing kernel
79 cgh.parallel_for(range_npart, [=](sycl::item<1> item) {
80 u32 gid = (u32) item.get_id();
81
82 vec3 &vxyz = acc_vxyz[item];
83
84 acc_xyz[item] = acc_xyz[item] + dt * vxyz;
85 });
86 };
87
88 queue.submit(ker_predict_step);
89 }
90
91 inline static void sycl_position_modulo(
92 sycl::queue &queue,
93 u32 npart,
94 const std::unique_ptr<sycl::buffer<vec3>> &buf_xyz,
95 std::tuple<vec3, vec3> box) {
96
97 sycl::range<1> range_npart{npart};
98
99 auto ker_predict_step = [&](sycl::handler &cgh) {
100 auto xyz = buf_xyz->template get_access<sycl::access::mode::read_write>(cgh);
101
102 vec3 box_min = std::get<0>(box);
103 vec3 box_max = std::get<1>(box);
104 vec3 delt = box_max - box_min;
105
106 // Executing kernel
107 cgh.parallel_for(range_npart, [=](sycl::item<1> item) {
108 u32 gid = (u32) item.get_id();
109
110 vec3 r = xyz[gid] - box_min;
111
112 r = sycl::fmod(r, delt);
113 r += delt;
114 r = sycl::fmod(r, delt);
115 r += box_min;
116
117 xyz[gid] = r;
118 });
119 };
120
121 queue.submit(ker_predict_step);
122 }
123
124 // mandatory variables
// Scheduler that owns the patches; step() iterates over its patches and reads its
// field layout (sched.pdl). Held by reference, so it must outlive this integrator.
125 PatchScheduler &sched;
// When true, step() applies sycl_position_modulo after the position update and
// forwards the flag to particle redistribution and to the interface handler calls.
126 bool periodic_mode;
// Factor applied to the per-patch maximum smoothing length (times Kernel::Rkern) in
// step(); also forwarded to compute_int_boxes and to lambda_compute_forces.
127 flt htol_up_tol;
// Forwarded together with htol_up_tol when the smoothing length iterator is built in
// step(). NOTE(review): exact meaning is defined by that helper, not visible here.
128 flt htol_up_iter;
129
// Forwarded to iterate_smoothing_length in step().
// NOTE(review): presumably the mass of one SPH particle - confirm against the helper.
130 flt sph_gpart_mass;
131
// Stores every argument in the member of the same name; no other work is done.
132 LeapfrogGeneral(
133 PatchScheduler &sched,
134 bool periodic_mode,
135 flt htol_up_tol,
136 flt htol_up_iter,
137 flt sph_gpart_mass)
138 : sched(sched), periodic_mode(periodic_mode), htol_up_tol(htol_up_tol),
139 htol_up_iter(htol_up_iter), sph_gpart_mass(sph_gpart_mass) {}
140
/**
 * @brief Perform one leapfrog step and return the time reached.
 *
 * Sequence visible in this function:
 *  1. cfl_glb_var is filled from lambda_cfl, reduced, and its value is used
 *     directly as the timestep dt_cur;
 *  2. per patch: lambda_update_time(dt_cur / 2), sycl_move_parts(dt_cur),
 *     lambda_update_time(dt_cur / 2), lambda_swap_der, then
 *     sycl_position_modulo when periodic_mode is set;
 *  3. reatribute_particles, then h_field is computed per patch, the interface
 *     list is built and communicated, and merged patch data is created;
 *  4. per patch: a RadixTree is built, then its cell bounding boxes and its
 *     interaction radii (RadixTreeField) are computed;
 *  5. per patch: iterate_smoothing_length fills hnew / omega, and hnew is
 *     copied into the "hpart" field of the merged data;
 *  6. hnew and omega are exchanged across interfaces and merged, then
 *     lambda_post_sync is called;
 *  7. lambda_compute_forces is called when do_force is true;
 *  8. per non-empty patch, lambda_correct(dt_cur / 2) is called when
 *     do_corrector is true;
 *  9. write_back_merge_patches, then old_time + dt_cur is returned.
 *
 * NOTE(review): this listing carries a line number prefix on every line and
 * several lines are absent (the numbering jumps, e.g. 170 -> 172, 180 -> 182,
 * 186 -> 188, 250 -> 252, 404 -> 406). The openings of some calls and
 * declarations, and the first argument of most lambda calls, are therefore
 * not visible here.
 *
 * @param old_time              time at the start of the step.
 * @param do_force              when true, lambda_compute_forces is invoked.
 * @param do_corrector          when true, lambda_correct is invoked per non-empty patch.
 * @param lambda_cfl            forwarded to cfl_glb_var.compute_var_patch(sched, lambda_cfl).
 * @param lambda_update_time    invoked twice per patch, each time with dt_cur / 2.
 * @param lambda_swap_der       invoked once per patch after the second time update.
 * @param lambda_post_sync      invoked as (sched, merge_pdat, hnew_field_merged, omega_field_merged).
 * @param lambda_compute_forces invoked as (sched, radix_trees, cell_int_rads, merge_pdat,
 *                              hnew_field_merged, omega_field_merged, htol_up_tol).
 * @param lambda_correct        invoked with the merged patch data, a range of
 *                              or_element_cnt elements and dt_cur / 2.
 * @return old_time + dt_cur.
 */
141 template<
142 class LambdaCFL,
143 class LambdaUpdateTime,
144 class LambdaSwapDer,
145 class LambdaPostSync,
146 class LambdaForce,
147 class LambdaCorrector>
148 inline flt step(
149 flt old_time,
150 bool do_force,
151 bool do_corrector,
152 LambdaCFL &&lambda_cfl,
153 LambdaUpdateTime &&lambda_update_time,
154 LambdaSwapDer &&lambda_swap_der,
155 LambdaPostSync &&lambda_post_sync,
156 LambdaForce &&lambda_compute_forces,
157 LambdaCorrector &&lambda_correct) {
158
159 using namespace shamrock::patch;
160
// Local copies of the members; loc_htol_up_tol is captured by value in the h_field lambda.
// NOTE(review): loc_htol_up_iter is not referenced again in the visible part of this function.
161 const flt loc_htol_up_tol = htol_up_tol;
162 const flt loc_htol_up_iter = htol_up_iter;
163
// Field indices looked up by name in the scheduler's layout.
// NOTE(review): iaxyz and iaxyz_old are not used in the visible part of this function.
164 const u32 ixyz = sched.pdl.get_field_idx<vec3>("xyz");
165 const u32 ivxyz = sched.pdl.get_field_idx<vec3>("vxyz");
166 const u32 iaxyz = sched.pdl.get_field_idx<vec3>("axyz");
167 const u32 iaxyz_old = sched.pdl.get_field_idx<vec3>("axyz_old");
168
169 const u32 ihpart = sched.pdl.get_field_idx<flt>("hpart");
170
// NOTE(review): the opening of this logging call (line 171 of the listing) is missing here.
172 "SPHLeapfrog",
173 "step t=",
174 old_time,
175 "do_force =",
176 do_force,
177 "do_corrector =",
178 do_corrector);
179
180 // Init serial patch tree
// NOTE(review): the declaration of sptree (line 181 of the listing) is missing here;
// only its constructor arguments are visible.
182 sched.patch_tree, sched.get_sim_box().template get_patch_transform<vec3>());
183 sptree.attach_buf();
184
185 // compute cfl
186 auto get_cfl = [&]() -> flt {
// NOTE(review): the declaration of cfl_glb_var (line 187 of the listing) is missing here.
188 cfl_glb_var.compute_var_patch(sched, lambda_cfl);
189 cfl_glb_var.reduce_val();
190 return cfl_glb_var.get_val();
191 };
192
193 flt cfl_val = get_cfl();
194
195 // compute dt step
196
// The reduced value is used as the timestep without any further scaling.
197 flt dt_cur = cfl_val;
198
199 logger::info_ln("SPHLeapfrog", "current dt :", dt_cur);
200
201 // advance time
202 flt step_time = old_time;
203 step_time += dt_cur;
204
205 // leapfrog predictor
206 sched.for_each_patch_data([&](u64 id_patch, Patch cur_p, PatchData &pdat) {
207 shamlog_debug_ln("SPHLeapfrog", "patch : n", id_patch, "->", "predictor");
208
// First half-step time update (dt_cur / 2); the first argument of the call is not visible here.
209 lambda_update_time(
211 pdat,
212 sycl::range<1>{pdat.get_obj_cnt()},
213 dt_cur / 2);
214
// Full-step position update: xyz += dt_cur * vxyz.
215 sycl_move_parts(
217 pdat.get_obj_cnt(),
218 dt_cur,
219 pdat.get_field<vec3>(ixyz).get_buf(),
220 pdat.get_field<vec3>(ivxyz).get_buf());
221
// Second half-step time update (dt_cur / 2).
222 lambda_update_time(
224 pdat,
225 sycl::range<1>{pdat.get_obj_cnt()},
226 dt_cur / 2);
227
228 shamlog_debug_ln("SPHLeapfrog", "patch : n", id_patch, "->", "dt fields swap");
229
230 lambda_swap_der(
232 pdat,
233 sycl::range<1>{pdat.get_obj_cnt()});
234
235 if (periodic_mode) { // TODO generalise position modulo in the scheduler
236 sycl_position_modulo(
238 pdat.get_obj_cnt(),
239 pdat.get_field<vec3>(ixyz).get_buf(),
240 sched.get_box_volume<vec3>());
241 }
242 });
243
244 // move particles between patches
245 shamlog_debug_ln("SPHLeapfrog", "particle reatribution");
246 reatribute_particles(sched, sptree, periodic_mode);
247
248 shamlog_debug_ln("SPHLeapfrog", "compute hmax of each patches");
249
250 // compute hmax
// NOTE(review): the declaration of h_field (line 251 of the listing) is missing here.
252 // sched.compute_patch_field(
253 // h_field, get_mpi_type<flt>(), [loc_htol_up_tol](sycl::queue &queue, Patch &p,
254 // PatchDataBuffer &pdat_buf) {
255 // return patchdata::sph::get_h_max<flt>(pdat_buf.pdl, queue, pdat_buf) *
256 // loc_htol_up_tol * Kernel::Rkern;
257 // });
258
// Per-patch value: get_h_max * htol_up_tol * Kernel::Rkern.
259 sched.compute_patch_field(
260 h_field,
261 get_mpi_type<flt>(),
262 [loc_htol_up_tol](sycl::queue &queue, Patch &p, PatchData &pdat) {
263 return patchdata::sph::get_h_max<flt>(pdat.pdl, queue, pdat)
264 * loc_htol_up_tol * Kernel::Rkern;
265 });
266
267 shamlog_debug_ln("SPHLeapfrog", "compute interface list");
268 // make interfaces
269 LegacyInterfacehandler<vec3, flt> interface_hndl;
270 interface_hndl.template compute_interface_list<InterfaceSelector_SPH<vec3, flt>>(
271 sched, sptree, h_field, periodic_mode);
272
273 shamlog_debug_ln("SPHLeapfrog", "communicate interfaces");
274 interface_hndl.comm_interfaces(sched, periodic_mode);
275
276 shamlog_debug_ln("SPHLeapfrog", "merging interfaces with data");
277 // merging strategy
278
279 // old
280 // std::unordered_map<u64, MergedPatchDataBuffer<vec3>> merge_pdat_buf;
281 // make_merge_patches(sched, interface_hndl, merge_pdat_buf);
282
// Merged data per patch id; or_element_cnt is the size used for the per-patch kernels below.
283 std::unordered_map<u64, MergedPatchData<flt>> merge_pdat
284 = MergedPatchData<flt>::merge_patches(sched, interface_hndl);
285
287
288 constexpr u32 reduc_level = 5;
289
290 // make trees
291 std::unordered_map<u64, std::unique_ptr<RadixTree<u_morton, vec3>>> radix_trees;
292 std::unordered_map<u64, std::unique_ptr<RadixTreeField<flt>>> cell_int_rads;
293
294 sched.for_each_patch([&](u64 id_patch, const Patch & /*cur_p*/) {
295 shamlog_debug_ln(
296 "SPHLeapfrog", "patch : n", id_patch, "->", "making Radix Tree");
297
// NOTE(review): this check only logs; there is no return, so the tree build below
// still runs for an empty patch. The same pattern repeats in the next three
// per-patch loops - confirm this is intended.
298 if (merge_pdat.at(id_patch).or_element_cnt == 0)
299 shamlog_debug_ln(
300 "SPHLeapfrog",
301 "patch : n",
302 id_patch,
303 "->",
304 "is empty skipping tree build");
305
306 // PatchDataBuffer &mpdat_buf = *merge_pdat_buf.at(id_patch).data;
307 PatchData &mpdat = merge_pdat.at(id_patch).data;
308
309 auto &buf_xyz = mpdat.get_field<vec3>(ixyz).get_buf();
310
311 std::tuple<vec3, vec3> &box = merge_pdat.at(id_patch).box;
312
313 // radix tree computation
314 radix_trees[id_patch] = std::make_unique<RadixTree<u_morton, vec3>>(
316 box,
317 buf_xyz,
318 mpdat.get_obj_cnt(),
319 reduc_level);
320 });
321
322 sched.for_each_patch([&](u64 id_patch, Patch /*cur_p*/) {
323 shamlog_debug_ln(
324 "SPHLeapfrog",
325 "patch : n",
326 id_patch,
327 "->",
328 "compute radix tree cell volumes");
329 if (merge_pdat.at(id_patch).or_element_cnt == 0)
330 shamlog_debug_ln(
331 "SPHLeapfrog",
332 "patch : n",
333 id_patch,
334 "->",
335 "is empty skipping tree volumes step");
336
// NOTE(review): the arguments of both calls below (lines 338 and 340 of the listing) are missing here.
337 radix_trees[id_patch]->compute_cell_ibounding_box(
339 radix_trees[id_patch]->convert_bounding_box(
341 });
342
343 sched.for_each_patch([&](u64 id_patch, Patch /*cur_p*/) {
344 shamlog_debug_ln(
345 "SPHLeapfrog",
346 "patch : n",
347 id_patch,
348 "->",
349 "compute Radix Tree interaction boxes");
350 if (merge_pdat.at(id_patch).or_element_cnt == 0)
351 shamlog_debug_ln(
352 "SPHLeapfrog",
353 "patch : n",
354 id_patch,
355 "->",
356 "is empty skipping interaction box compute");
357
358 PatchData &mpdat = merge_pdat.at(id_patch).data;
359
360 auto &buf_h = mpdat.get_field<flt>(ihpart).get_buf();
361
// Interaction radii are computed from the merged "hpart" buffer and htol_up_tol.
362 cell_int_rads[id_patch] = std::make_unique<RadixTreeField<flt>>(
363 radix_trees[id_patch]->compute_int_boxes(
364 shamsys::instance::get_compute_queue(), buf_h, htol_up_tol));
365 });
367
368 // create compute field for new h and omega
369 shamlog_debug_ln("SPHLeapfrog", "init compute fields : hnew, omega");
370
371 PatchComputeField<flt> hnew_field;
372 PatchComputeField<flt> omega_field;
373
374 hnew_field.generate(sched);
375 omega_field.generate(sched);
376
377 // iterate smoothing length
378 sched.for_each_patch([&](u64 id_patch, Patch /*cur_p*/) {
379 shamlog_debug_ln(
380 "SPHLeapfrog", "patch : n", id_patch, "->", "Init h iteration");
381 if (merge_pdat.at(id_patch).or_element_cnt == 0)
382 shamlog_debug_ln(
383 "SPHLeapfrog",
384 "patch : n",
385 id_patch,
386 "->",
387 "is empty skipping h iteration");
388
389 PatchData &pdat_merge = merge_pdat.at(id_patch).data;
390
391 auto &hnew = hnew_field.get_buf(id_patch);
392 auto &omega = omega_field.get_buf(id_patch);
// Scratch buffer with one element per original element, passed to iterate_smoothing_length.
393 sycl::buffer<flt> eps_h
394 = sycl::buffer<flt>(merge_pdat.at(id_patch).or_element_cnt);
395
396 sycl::range range_npart{merge_pdat.at(id_patch).or_element_cnt};
397
398 shamlog_debug_ln(
399 "SPHLeapfrog",
400 "merging -> original size :",
401 merge_pdat.at(id_patch).or_element_cnt,
402 "| merged :",
403 pdat_merge.get_obj_cnt());
404
// NOTE(review): the declaration of h_iterator (line 405 of the listing) is missing here;
// only its constructor arguments are visible.
406 sched.pdl, htol_up_tol, htol_up_iter);
407
408 h_iterator.iterate_smoothing_length(
410 merge_pdat.at(id_patch).or_element_cnt,
411 sph_gpart_mass,
412 *radix_trees[id_patch],
413 *cell_int_rads[id_patch],
414 pdat_merge,
415 *hnew,
416 *omega,
417 eps_h);
418
419 // write back h test
420 //*
// Copies hnew into the merged "hpart" field for the first or_element_cnt elements.
421 shamsys::instance::get_compute_queue().submit([&](sycl::handler &cgh) {
422 auto h_new = hnew->template get_access<sycl::access::mode::read>(cgh);
423
424 auto acc_hpart = pdat_merge.get_field<flt>(ihpart)
425 .get_buf()
426 ->template get_access<sycl::access::mode::write>(cgh);
427
428 cgh.parallel_for(range_npart, [=](sycl::item<1> item) {
429 acc_hpart[item] = h_new[item];
430 });
431 });
432 //*/
433 });
434
435 shamlog_debug_ln("SPHLeapfrog", "exchange interface hnew");
436 PatchComputeFieldInterfaces<flt> hnew_field_interfaces
437 = interface_hndl.template comm_interfaces_field<flt>(
438 sched, hnew_field, periodic_mode);
439 shamlog_debug_ln("SPHLeapfrog", "exchange interface omega");
440 PatchComputeFieldInterfaces<flt> omega_field_interfaces
441 = interface_hndl.template comm_interfaces_field<flt>(
442 sched, omega_field, periodic_mode);
443
444 // merge compute fields
445 // std::unordered_map<u64, MergedPatchCompFieldBuffer<flt>> hnew_field_merged;
446 // make_merge_patches_comp_field<flt>(sched, interface_hndl, hnew_field,
447 // hnew_field_interfaces, hnew_field_merged); std::unordered_map<u64,
448 // MergedPatchCompFieldBuffer<flt>> omega_field_merged;
449 // make_merge_patches_comp_field<flt>(sched, interface_hndl, omega_field,
450 // omega_field_interfaces, omega_field_merged);
451
452 std::unordered_map<u64, MergedPatchCompField<flt, flt>> hnew_field_merged
453 = MergedPatchCompField<flt, flt>::merge_patches_cfield(
454 sched, interface_hndl, hnew_field, hnew_field_interfaces);
455
456 std::unordered_map<u64, MergedPatchCompField<flt, flt>> omega_field_merged
457 = MergedPatchCompField<flt, flt>::merge_patches_cfield(
458 sched, interface_hndl, omega_field, omega_field_interfaces);
459
460 // TODO add looping on corrector step
461
462 lambda_post_sync(sched, merge_pdat, hnew_field_merged, omega_field_merged);
463
464 // compute force
465 if (do_force) {
466 lambda_compute_forces(
467 sched,
468 radix_trees,
469 cell_int_rads,
470 merge_pdat,
471 hnew_field_merged,
472 omega_field_merged,
473 htol_up_tol);
474 }
475
476 // leapfrog corrector
477 sched.for_each_patch([&](u64 id_patch, Patch /*cur_p*/) {
// Unlike the earlier per-patch loops, an empty patch really is skipped here.
// NOTE(review): this writes to std::cout while the rest of the function uses the
// logging helpers, and <iostream> is not among the includes visible in this listing.
478 if (merge_pdat.at(id_patch).or_element_cnt == 0) {
479 std::cout << " empty => skipping" << std::endl;
480 return;
481 }
482
483 PatchData &pdat_merge = merge_pdat.at(id_patch).data;
484
485 if (do_corrector) {
486
487 shamlog_debug_ln("SPHLeapfrog", "leapfrog corrector");
488
// Corrector applied over the original elements only, with half the timestep.
489 lambda_correct(
491 pdat_merge,
492 sycl::range<1>{merge_pdat.at(id_patch).or_element_cnt},
493 dt_cur / 2);
494 }
495 });
496
497 // write_back_merge_patches(sched, interface_hndl, merge_pdat);
498 write_back_merge_patches(sched, merge_pdat);
499
500 return step_time;
501 }
502 };
503
504 } // namespace sph
505
506} // namespace integrators
sycl::queue & get_compute_queue(u32 id=0)
MPI scheduler.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
The MPI scheduler.
Define a field attached to a patch (example: FMM multipoles, hmax in SPH).
header for PatchData related function and declaration
void info_ln(std::string module_name, Types... var2)
Prints a log message with multiple arguments followed by a newline.
Definition logs.hpp:133
Patch object that contain generic patch information.
Definition Patch.hpp:33