Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
SPHSetup.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
20#include "shambase/memory.hpp"
21#include "shambase/string.hpp"
22#include "shambase/tabulate.hpp"
28#include "shamcomm/logs.hpp"
30#include "shamcomm/wrapper.hpp"
47#include <mpi.h>
48#include <vector>
49
50template<class Tvec, template<class> class SPHKernel>
51inline std::shared_ptr<shammodels::sph::modules::ISPHSetupNode> shammodels::sph::modules::
52 SPHSetup<Tvec, SPHKernel>::make_generator_lattice_hcp(
53 Tscal dr, std::pair<Tvec, Tvec> box, bool discontinuous) {
54 if (discontinuous) {
55 return std::shared_ptr<ISPHSetupNode>(
56 new GeneratorLatticeHCP<Tvec, true>(context, dr, box));
57 } else {
58 return std::shared_ptr<ISPHSetupNode>(
59 new GeneratorLatticeHCP<Tvec, false>(context, dr, box));
60 }
61}
62
63template<class Tvec, template<class> class SPHKernel>
64inline std::shared_ptr<shammodels::sph::modules::ISPHSetupNode> shammodels::sph::modules::
65 SPHSetup<Tvec, SPHKernel>::make_generator_lattice_cubic(Tscal dr, std::pair<Tvec, Tvec> box) {
66 return std::shared_ptr<ISPHSetupNode>(new GeneratorLatticeCubic<Tvec>(context, dr, box));
67}
68
69template<class Tvec, template<class> class SPHKernel>
70inline std::shared_ptr<shammodels::sph::modules::ISPHSetupNode> shammodels::sph::modules::
71 SPHSetup<Tvec, SPHKernel>::make_generator_disc_mc(
72 Tscal part_mass,
73 Tscal disc_mass,
74 Tscal r_in,
75 Tscal r_out,
76 std::function<Tscal(Tscal)> sigma_profile,
77 std::function<Tscal(Tscal)> H_profile,
78 std::function<Tscal(Tscal)> rot_profile,
79 std::function<Tscal(Tscal)> cs_profile,
80 std::mt19937_64 eng,
81 Tscal init_h_factor) {
82 return std::shared_ptr<ISPHSetupNode>(new GeneratorMCDisc<Tvec, SPHKernel>(
83 context,
84 solver_config,
85 part_mass,
86 disc_mass,
87 r_in,
88 r_out,
89 sigma_profile,
90 H_profile,
91 rot_profile,
92 cs_profile,
93 eng,
94 init_h_factor));
95}
96
97template<class Tvec, template<class> class SPHKernel>
98inline std::shared_ptr<shammodels::sph::modules::ISPHSetupNode> shammodels::sph::modules::
99 SPHSetup<Tvec, SPHKernel>::make_generator_from_context(ShamrockCtx &context_other) {
100 return std::shared_ptr<ISPHSetupNode>(
101 new GeneratorFromOtherContext<Tvec>(context, context_other));
102}
103
104template<class Tvec, template<class> class SPHKernel>
105inline std::shared_ptr<shammodels::sph::modules::ISPHSetupNode> shammodels::sph::modules::
106 SPHSetup<Tvec, SPHKernel>::make_combiner_add(SetupNodePtr parent1, SetupNodePtr parent2) {
107 return std::shared_ptr<ISPHSetupNode>(new CombinerAdd<Tvec>(context, parent1, parent2));
108}
109
/// Serial (one-shot) setup application: repeatedly pulls batches of particles
/// from the setup graph `setup` and pushes them straight into the scheduler,
/// load balancing after each insertion, then runs a few final balancing passes.
/// NOTE(review): the signature head of this member definition was lost when
/// this file was extracted; the visible parameters are
/// (SetupNodePtr setup, bool part_reordering, std::optional<u32> insert_step).
template<class Tvec, template<class> class SPHKernel>
    SetupNodePtr setup, bool part_reordering, std::optional<u32> insert_step) {

    // Guard: a null setup graph is a caller error.
    if (!bool(setup)) {
        shambase::throw_with_loc<std::invalid_argument>("The setup shared pointer is empty");
    }

    shambase::Timer time_setup;
    time_setup.start();
    StackEntry stack_loc{};

    PatchScheduler &sched = shambase::get_check_ref(context.sched);

    // Recompute the per-patch load-balancing value (invoked after each insertion).
    auto compute_load = [&]() {
        modules::ComputeLoadBalanceValue<Tvec, SPHKernel>(context, solver_config, storage)
            .update_load_balancing();
    };

    // True if this rank currently owns at least one local patch.
    auto has_pdat = [&]() {
        bool ret = false;
        using namespace shamrock::patch;
        sched.for_each_local_patchdata([&](const Patch &p, PatchDataLayer &pdat) {
            ret = true;
        });
        return ret;
    };

    shamrock::DataInserterUtility inserter(sched);
    // Batch size per generation step, defaulted from the patch split criterion.
    u32 _insert_step = sched.crit_patch_split * 8;
    if (bool(insert_step)) {
        _insert_step = insert_step.value();
    }

    while (!setup->is_done()) {

        // Generate the next batch (0 requested on ranks that hold no patch yet).
        shamrock::patch::PatchDataLayer pdat = setup->next_n((has_pdat()) ? _insert_step : 0);

        if (solver_config.track_particles_id) {
            // This bit set the tracking id of the particles
            // But be carefull this assume that the particle injection order
            // is independant from the MPI world size. It should be the case for most setups
            // but some generator could miss this assumption.
            // If that is the case please report the issue

            u64 loc_inj = pdat.get_obj_cnt();

            u64 offset_init = 0;
            // NOTE(review): the collective-call name preceding these arguments
            // was lost in extraction (presumably an MPI exclusive scan over the
            // local injection counts — TODO confirm against the repository).
                &loc_inj, &offset_init, 1, get_mpi_type<u64>(), MPI_SUM, MPI_COMM_WORLD);

            // we must add the number of already injected part such that the
            // offset start at the right spot.
            // The only thing that bothers me is that this can not handle the case where multiple
            // setups of things like that are applied. But in principle no sane person would do such
            // a thing...
            offset_init += injected_parts;

            auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
            auto &q = shambase::get_check_ref(dev_sched).get_queue();

            if (loc_inj > 0) {
                sham::DeviceBuffer<u64> part_ids(loc_inj, dev_sched);

                // Fill part_ids[i] = offset_init + i on device.
                // NOTE(review): the kernel-launch helper name was lost in
                // extraction (only its argument list remains below).
                    q,
                    sham::MultiRef{part_ids},
                    loc_inj,
                    [offset_init](u32 i, u64 *__restrict part_ids) {
                        part_ids[i] = i + offset_init;
                    });

                // Write the computed ids into the batch's "part_id" field.
                pdat.get_field<u64>(pdat.pdl().get_field_idx<u64>("part_id"))
                    .overwrite(part_ids, loc_inj);
            }
        }

        // Push the batch into the scheduler (sorted into patches by "xyz"),
        // rebalancing through `compute_load`.
        u64 injected
            = inserter.push_patch_data<Tvec>(pdat, "xyz", sched.crit_patch_split * 8, compute_load);

        injected_parts += injected;
    }

    // A few extra balancing passes once everything has been inserted.
    u32 final_balancing_steps = 3;
    for (u32 i = 0; i < final_balancing_steps; i++) {
        ON_RANK_0(
            logger::info_ln(
                "SPH setup", "Final load balancing step", i, "of", final_balancing_steps));
        inserter.balance_load(compute_load);
    }

    if (part_reordering) {
        modules::ParticleReordering<Tvec, u32, SPHKernel>(context, solver_config, storage)
            .reorder_particles();
    }

    time_setup.end();
    if (shamcomm::world_rank() == 0) {
        logger::info_ln("SPH setup", "the setup took :", time_setup.elasped_sec(), "s");
    }
}
212
213struct SetupLog {
214 struct State {
215 std::vector<u64> count_per_rank;
216 std::vector<std::tuple<u32, u32, u64>> msg_list;
217 } state;
218
219 u64 step_counter = 0;
220
221 nlohmann::json json_data = nlohmann::json::array();
222
223 void log_state() {
224 nlohmann::json step_data;
225 step_data["step_counter"] = step_counter;
226 step_data["count_per_rank"] = state.count_per_rank;
227 step_data["msg_list"] = state.msg_list;
228 json_data.push_back(step_data);
229 }
230
231 void dump_state() {
232 std::string fname = "setup_log_step.json";
233 if (shamcomm::world_rank() == 0) {
234 logger::normal_ln("SPH setup", "dumping setup log to ", fname);
235 }
236
237 std::ofstream file(fname);
238 file << json_data.dump(4);
239 file.close();
240
241 step_counter++;
242 }
243
244 void update_count_per_rank(u64 count) {
245 std::vector<u64> tmp{count};
246 std::vector<u64> recv_count_per_rank;
247 shamalgs::collective::vector_allgatherv(tmp, recv_count_per_rank, MPI_COMM_WORLD);
248 state.count_per_rank = recv_count_per_rank;
249 log_state();
250 if (step_counter % 20 == 0)
251 dump_state();
252 }
253
254 void update_msg_list(std::vector<std::tuple<u32, u32, u64>> &msg_list) {
255 state.msg_list = msg_list;
256 log_state();
257 if (step_counter % 20 == 0)
258 dump_state();
259 }
260};
261
/// Golden ratio; used below to derive deterministic but decorrelated RNG seeds
/// for the message-list shuffles of the injection loop.
inline constexpr f64 golden_number = 1.61803398874989484820458683436563;
263
/// Progressive ("fast") setup application. Phase 1 drains the setup graph into
/// a local staging buffer `to_insert`; phase 2 injects staged particles into
/// the scheduler through rate-limited sparse MPI exchanges until every rank's
/// staging buffer is empty; a perf report is printed at the end.
/// NOTE(review): several identifier/call-name lines of this definition were
/// lost when this file was extracted (the signature head included); the
/// orphaned argument lists are kept verbatim and flagged where they occur.
template<class Tvec, template<class> class SPHKernel>
    SetupNodePtr setup,
    bool part_reordering,
    std::optional<u32> gen_count_per_step,
    std::optional<u32> insert_count_per_step,
    std::optional<u64> max_msg_count_per_rank_per_step,
    std::optional<u64> max_data_count_per_rank_per_step,
    std::optional<u64> max_msg_size,
    bool do_setup_log,
    bool speculative_balancing) {

    // NOTE(review): a statement (presumably capturing `mem_perf_infos_start`,
    // read at the end of this function) was lost in extraction here.

    if (!bool(setup)) {
        shambase::throw_with_loc<std::invalid_argument>("The setup shared pointer is empty");
    }

    // Optional JSON logger of per-step counts / message lists.
    std::optional<SetupLog> setup_log
        = (do_setup_log) ? std::make_optional<SetupLog>() : std::nullopt;

    shambase::Timer time_setup;
    time_setup.start();
    PatchScheduler &sched = shambase::get_check_ref(context.sched);
    shamrock::DataInserterUtility inserter(sched);

    // Tunables, all defaulted from the scheduler split criterion unless
    // explicitly overridden by the caller.
    u32 insert_step = sched.crit_patch_split * 2;
    if (bool(insert_count_per_step)) {
        insert_step = insert_count_per_step.value();
    }

    u32 gen_step = std::max(sched.crit_patch_split / 8, 1_u64);
    if (bool(gen_count_per_step)) {
        gen_step = gen_count_per_step.value();
    }

    u64 msg_limit = 1024;
    if (bool(max_msg_count_per_rank_per_step)) {
        msg_limit = max_msg_count_per_rank_per_step.value();
    }
    u64 data_count_limit = insert_step;
    if (bool(max_data_count_per_rank_per_step)) {
        data_count_limit = max_data_count_per_rank_per_step.value();
    }
    u64 max_message_size = std::max(insert_step / 16, 1_u32);
    if (bool(max_msg_size)) {
        max_message_size = max_msg_size.value();
    }

    // Staging buffer of generated particles not yet owned by any patch.
    shamrock::patch::PatchDataLayer to_insert(sched.get_layout_ptr_old());

    u64 speculative_last_npatch = 0;
    shambase::DistributedData<u64> speculative_load_values = {};

    // Load-balance value provider. In speculative mode the load of each patch
    // also counts staged (not-yet-inserted) particles falling in its AABB, so
    // the balancer anticipates where the data will end up.
    auto compute_load = [&]() {
        if (speculative_balancing) {

            StackEntry stack_loc{};

            auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();

            u64 npatch = scheduler().patch_list.global.size();

            // check if the number of patches has changed, rebuild otherwise
            if (npatch != speculative_last_npatch) {

                shambase::details::NamedBasicStackEntry stack_loc2{"compute_load"};

                if (shamcomm::world_rank() == 0) {
                    logger::normal_ln(
                        "SPH setup",
                        "number of patches has changed, rebuilding speculative load values");
                }

                // reset the load values
                speculative_last_npatch = npatch;
                speculative_load_values.reset();

                // Compute the AABB of all the patches

                std::vector<Tvec> patch_aabb_min(npatch);
                std::vector<Tvec> patch_aabb_max(npatch);

                auto &global_patch_list = scheduler().patch_list.global;
                // NOTE(review): the declaration on the preceding source line
                // (the patch coordinate transform `ptransf`, judging by its
                // use just below) was lost in extraction.
                    = sched.get_sim_box().get_patch_transform<Tvec>();

                for (size_t i = 0; i < global_patch_list.size(); i++) {
                    const shamrock::patch::Patch &p = global_patch_list[i];
                    if (!p.is_err_mode()) {
                        shammath::CoordRange<Tvec> patch_coord = ptransf.to_obj_coord(p);
                        patch_aabb_min[i] = patch_coord.lower;
                        patch_aabb_max[i] = patch_coord.upper;
                    }
                }

                sham::DeviceBuffer<Tvec> buf_patch_aabb_min(npatch, dev_sched);
                sham::DeviceBuffer<Tvec> buf_patch_aabb_max(npatch, dev_sched);

                buf_patch_aabb_min.copy_from_stdvec(patch_aabb_min);
                buf_patch_aabb_max.copy_from_stdvec(patch_aabb_max);

                // count the number of particles in each patch

                sham::DeviceBuffer<u64> local_load_values(npatch, dev_sched);
                local_load_values.fill(0);

                PatchDataField<Tvec> &xyz = to_insert.get_field<Tvec>(0);

                if (xyz.get_obj_cnt() > 0) {
                    // NOTE(review): the kernel-launch helper name was lost in
                    // extraction (only its argument list remains below).
                        shamsys::instance::get_compute_scheduler().get_queue(),
                        sham::MultiRef{xyz.get_buf(), buf_patch_aabb_min, buf_patch_aabb_max},
                        sham::MultiRef{local_load_values},
                        xyz.get_obj_cnt(),
                        [npatch](
                            u32 i,
                            const Tvec *__restrict xyz,
                            const Tvec *__restrict patch_aabb_min,
                            const Tvec *__restrict patch_aabb_max,
                            u64 *__restrict local_load_values) {
                            Tvec pos = xyz[i];
                            // O(npatch) scan per particle; counts are folded in
                            // with device-scope relaxed atomics.
                            for (size_t j = 0; j < npatch; j++) {
                                // NOTE(review): the AABB/range declaration on
                                // the preceding source line was lost in
                                // extraction.
                                        = {patch_aabb_min[j], patch_aabb_max[j]};
                                if (patch_coord.contain_pos(pos)) {
                                    sycl::atomic_ref<
                                        u64,
                                        sycl::memory_order::relaxed,
                                        sycl::memory_scope::device>
                                        atomic_local_load_values(local_load_values[j]);
                                    atomic_local_load_values++;
                                }
                            }
                        });
                }

                // recover data

                auto local_load_values_host = local_load_values.copy_to_stdvec();

                std::vector<u64> reduced_load_values(npatch);

                // reduce the load values

                // NOTE(review): the MPI reduction call name (a SUM allreduce,
                // per its arguments) was lost in extraction.
                    local_load_values_host.data(),
                    reduced_load_values.data(),
                    npatch,
                    get_mpi_type<u64>(),
                    MPI_SUM,
                    MPI_COMM_WORLD);

                // convert to DistributedData

                for (size_t i = 0; i < npatch; i++) {
                    speculative_load_values.add_obj(
                        global_patch_list[i].id_patch, u64(reduced_load_values[i]));
                }

                // Add the already injected parts to the load values

                auto &patch_list = scheduler().patch_list;

                for (u64 id : scheduler().owned_patch_id) {
                    // NOTE(review): the variable declaration on the preceding
                    // source line was lost in extraction.
                            = patch_list.local[patch_list.id_patch_to_local_idx[id]];
                    speculative_load_values.get(id)
                        += scheduler().patch_data.owned_data.get(id).get_obj_cnt();
                }
            }

            // update load values

            scheduler().update_local_load_value([&](shamrock::patch::Patch p) {
                return speculative_load_values.get(p.id_patch);
            });

        } else {
            modules::ComputeLoadBalanceValue<Tvec, SPHKernel>(context, solver_config, storage)
                .update_load_balancing();
        }
    };

    // True if this rank currently owns at least one local patch.
    auto has_pdat = [&]() {
        bool ret = false;
        using namespace shamrock::patch;
        sched.for_each_local_patchdata([&](const Patch &p, PatchDataLayer &pdat) {
            ret = true;
        });
        return ret;
    };

    shambase::Timer time_part_gen;
    time_part_gen.start();

    if (shamcomm::world_rank() == 0) {
        logger::normal_ln("SPH setup", "generating particles ...");
    }

    // ---- phase 1: drain the setup graph into the staging buffer ----
    while (!setup->is_done()) {
        shambase::Timer timer_gen;
        timer_gen.start();

        shamrock::patch::PatchDataLayer tmp = setup->next_n(gen_step);

        if (solver_config.track_particles_id) {
            // This bit set the tracking id of the particles
            // But be carefull this assume that the particle injection order
            // is independant from the MPI world size. It should be the case for most setups
            // but some generator could miss this assumption.
            // If that is the case please report the issue

            u64 loc_inj = tmp.get_obj_cnt();

            u64 offset_init = 0;
            // NOTE(review): the collective-call name preceding these arguments
            // was lost in extraction (presumably an MPI exclusive scan over the
            // local injection counts — TODO confirm against the repository).
                &loc_inj, &offset_init, 1, get_mpi_type<u64>(), MPI_SUM, MPI_COMM_WORLD);

            // we must add the number of already injected part such that the
            // offset start at the right spot.
            // The only thing that bothers me is that this can not handle the case where multiple
            // setups of things like that are applied. But in principle no sane person would do such
            // a thing...
            offset_init += injected_parts;

            auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
            auto &q = shambase::get_check_ref(dev_sched).get_queue();

            if (loc_inj > 0) {
                sham::DeviceBuffer<u64> part_ids(loc_inj, dev_sched);

                // Fill part_ids[i] = offset_init + i on device.
                // NOTE(review): the kernel-launch helper name was lost in
                // extraction (only its argument list remains below).
                    q,
                    sham::MultiRef{part_ids},
                    loc_inj,
                    [offset_init](u32 i, u64 *__restrict part_ids) {
                        part_ids[i] = i + offset_init;
                    });

                tmp.get_field<u64>(tmp.pdl().get_field_idx<u64>("part_id"))
                    .overwrite(part_ids, loc_inj);
            }
        }

        to_insert.insert_elements(tmp);

        u64 sum_push = shamalgs::collective::allreduce_sum<u64>(tmp.get_obj_cnt());
        u64 sum_all = shamalgs::collective::allreduce_sum<u64>(to_insert.get_obj_cnt());

        u64 min_rank = shamalgs::collective::allreduce_min<u64>(to_insert.get_obj_cnt());
        u64 max_rank = shamalgs::collective::allreduce_max<u64>(to_insert.get_obj_cnt());

        timer_gen.end();

        if (shamcomm::world_rank() == 0) {
            f64 part_per_sec = f64(sum_push) / f64(timer_gen.elasped_sec());
            logger::normal_ln(
                "SPH setup",
                shambase::format(
                    "Nstep = {} ( {:.1e} ) Ntotal = {} ( {:.1e} rank min = {:.1e} max = {:.1e}) "
                    "rate = {:e} N.s^-1",
                    sum_push,
                    f64(sum_push),
                    sum_all,
                    f64(sum_all),
                    part_per_sec,
                    f64(min_rank),
                    f64(max_rank)));
        }

        if (setup_log) {
            setup_log.value().update_count_per_rank(to_insert.get_obj_cnt());
        }

        injected_parts += sum_push;
    }

    time_part_gen.end();
    if (shamcomm::world_rank() == 0) {
        logger::normal_ln(
            "SPH setup", "the generation step took :", time_part_gen.elasped_sec(), "s");
    }

    if (shamcomm::world_rank() == 0) {
        logger::normal_ln(
            "SPH setup", "final particle count =", injected_parts, "beginning injection ...");
    }

    // NOTE(review): a statement (presumably capturing `mem_perf_infos_start`)
    // may have been lost in extraction just above.
    f64 mpi_timer_start = shamcomm::mpi::get_timer("total");

    // injection part (holy shit this is hard)

    shambase::Timer time_part_inject;
    time_part_inject.start();

    // Print a one-line global progress report of the injection phase.
    auto log_inject_status = [&](std::string log_suffix = "") {
        u64 sum_all = shamalgs::collective::allreduce_sum<u64>(to_insert.get_obj_cnt());

        u32 rank_without_patch
            = shamalgs::collective::allreduce_sum<u32>(sched.patch_list.local.size() == 0 ? 1 : 0);

        if (shamcomm::world_rank() == 0) {
            logger::normal_ln(
                "SPH setup",
                shambase::format(
                    "injected {:12} / {:} => {:5.1f}% | ranks with patchs = {:d} / {:d} {}",
                    injected_parts - sum_all,
                    injected_parts,
                    f64(injected_parts - sum_all) / f64(injected_parts) * 100.0,
                    shamcomm::world_size() - rank_without_patch,
                    // NOTE(review): one format argument (presumably the world
                    // size, to match the second {:d}) was lost in extraction.
                    log_suffix));
        }

        if (setup_log) {
            setup_log.value().update_count_per_rank(to_insert.get_obj_cnt());
        }
    };

    // Insert staged particles that already fall inside patches owned by this
    // rank, at most `insert_step` per patch per pass, rebalancing between
    // passes, until no rank is capped anymore.
    auto inject_in_local_domains =
        [&sched, &inserter, &compute_load, &insert_step, &log_inject_status](
        // NOTE(review): the tail of this lambda's parameter list was lost in
        // extraction (it takes the staging PatchDataLayer `to_insert` by
        // reference, judging by the call site below).

        bool has_been_limited = true;

        auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
        sham::DeviceBuffer<u32> mask_get_ids_where(0, dev_sched);

        while (has_been_limited) {
            has_been_limited = false;
            using namespace shamrock::patch;

            // inject in local domains first
            PatchCoordTransform<Tvec> ptransf = sched.get_sim_box().get_patch_transform<Tvec>();
            sched.for_each_local_patchdata([&](const Patch &p, PatchDataLayer &pdat) {
                shammath::CoordRange<Tvec> patch_coord = ptransf.to_obj_coord(p);

                PatchDataField<Tvec> &xyz = to_insert.get_field<Tvec>(0);

                auto ids = xyz.get_ids_where_recycle_buffer(
                    mask_get_ids_where,
                    [](auto access, u32 id, shammath::CoordRange<Tvec> patch_coord) {
                        Tvec tmp = access[id];
                        return patch_coord.contain_pos(tmp);
                    },
                    patch_coord);

                // Cap the batch; remember we were capped so another pass runs.
                if (ids.get_size() > insert_step) {
                    ids.resize(insert_step);
                    has_been_limited = true;
                }

                if (ids.get_size() > 0) {
                    to_insert.extract_elements(ids, pdat);
                }
            });

            sched.check_patchdata_locality_correctness();

            inserter.balance_load(compute_load);

            // Repeat the pass if ANY rank was capped.
            has_been_limited
                = !shamalgs::collective::are_all_rank_true(!has_been_limited, MPI_COMM_WORLD);

            if (has_been_limited) {
                // since we will restart this one let's print
                log_inject_status(" -> local loop <-");
            }
        }
    };

    // Classify every staged particle by the rank owning its destination patch;
    // returns rank -> list of staged indices. `timer_result` receives the
    // elapsed time of the classification.
    auto get_index_per_ranks = [&](f64 &timer_result) {

        shambase::Timer time_get_index_per_ranks;
        time_get_index_per_ranks.start();

        // NOTE(review): the declaration of `sptree` (a serial patch tree, per
        // its use below) on the preceding source line was lost in extraction.
        sptree.attach_buf();

        // find where each particle should be inserted
        PatchDataField<Tvec> &pos_field = to_insert.get_field<Tvec>(0);

        if (pos_field.get_nvar() != 1) {
            // NOTE(review): the statement inside this guard (presumably a
            // throw) was lost in extraction.
        }

        sycl::buffer<u64> new_id_buf = sptree.compute_patch_owner(
            shamsys::instance::get_compute_scheduler_ptr(),
            pos_field.get_buf(),
            pos_field.get_obj_cnt());

        std::unordered_map<i32, std::vector<u32>> index_per_ranks;
        bool err_id_in_newid = false;
        {
            sycl::host_accessor nid{new_id_buf, sycl::read_only};
            for (u32 i = 0; i < pos_field.get_obj_cnt(); i++) {
                u64 patch_id = nid[i];
                bool err = patch_id == u64_max;
                err_id_in_newid = err_id_in_newid || (err);

                i32 rank = sched.get_patch_rank_owner(patch_id);
                index_per_ranks[rank].push_back(i);
            }
        }

        if (err_id_in_newid) {
            // NOTE(review): the throw call preceding this message string was
            // lost in extraction.
                "a new id could not be computed");
        }

        time_get_index_per_ranks.end();
        timer_result = time_get_index_per_ranks.elasped_sec();

        return index_per_ranks;
    };

    f64 total_time_rank_getter = 0;
    f64 max_time_rank_getter = 0;

    // NOTE(review): a declaration (presumably the sparse-communication cache
    // `comm_cache`, used further down) was lost in extraction just above.
    u32 step_count = 0;
    // ---- phase 2: loop until the staging buffer is empty on all ranks ----
    while (!shamalgs::collective::are_all_rank_true(to_insert.is_empty(), MPI_COMM_WORLD)) {

        // assume that the sched is synchronized and that there is at least a patch.
        // TODO actually check that

        using namespace shamrock::patch;

        auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();

        inject_in_local_domains(to_insert);

        f64 timer_get_index_per_ranks = 0;
        std::unordered_map<i32, std::vector<u32>> index_per_ranks
            = get_index_per_ranks(timer_get_index_per_ranks);
        total_time_rank_getter += timer_get_index_per_ranks;
        max_time_rank_getter = std::max(max_time_rank_getter, timer_get_index_per_ranks);

        // allgather the list of messages
        // format:(u32_2(sender_rank, receiver_rank), u64(indices_size))
        std::vector<u64> send_msg;
        for (auto &[rank, indices] : index_per_ranks) {
            send_msg.push_back(sham::pack32(shamcomm::world_rank(), rank));
            send_msg.push_back(indices.size());
        }

        // Cap the allgather payload so the synchronization step itself cannot
        // blow up at large rank counts; the dropped entries are retried on a
        // later iteration of the outer loop.
        u64 max_send = (1 << 24) / shamcomm::world_size();
        bool sync_limited = false;
        if (send_msg.size() > max_send) {

            // here we must pack the send_msg infos in structs in order to keep
            // them together during shuffle

            struct tmp {
                u64 ranks, size;
            };

            // build the vector of structs
            std::vector<tmp> tmp_vec;
            tmp_vec.reserve(send_msg.size() / 2);
            for (u64 i = 0; i < send_msg.size(); i += 2) {
                tmp_vec.push_back({send_msg[i], send_msg[i + 1]});
            }

            // shuffle the messages infos
            u64 local_seed = u64(golden_number * 1000 * step_count + shamcomm::world_rank());
            std::mt19937_64 eng_local_msg(local_seed);
            std::shuffle(tmp_vec.begin(), tmp_vec.end(), eng_local_msg);

            // build the new send_msg
            std::vector<u64> send_msg_new;
            send_msg_new.reserve(max_send);
            for (auto &t : tmp_vec) {
                if (send_msg_new.size() >= max_send) {
                    break;
                }
                send_msg_new.push_back(t.ranks);
                send_msg_new.push_back(t.size);
            }

            send_msg = send_msg_new;
            sync_limited = true;
        }

        std::vector<u64> recv_msg;
        shamalgs::collective::vector_allgatherv(send_msg, recv_msg, MPI_COMM_WORLD);

        std::vector<std::tuple<u32, u32, u64>> msg_list;
        for (u64 i = 0; i < recv_msg.size(); i += 2) {
            u32_2 sender_receiver = sham::unpack32(recv_msg[i]);
            u64 indices_size = recv_msg[i + 1];

            u32 sender_rank = sender_receiver.x();
            u32 receiver_rank = sender_receiver.y();

            if (sender_rank == receiver_rank) {
                continue; // only mean that it was not fully inserted in the patch
            }

            msg_list.push_back(std::make_tuple(sender_rank, receiver_rank, indices_size));
        }

        if (setup_log) {
            setup_log.value().update_msg_list(msg_list);
        }

        // shuffle msg_list according to seed golden_number*1000*step_count
        std::mt19937 eng_global_msg(u64(golden_number * 1000 * step_count));
        std::shuffle(msg_list.begin(), msg_list.end(), eng_global_msg);

        // now that we are in sync we can determine who should send to who

        std::vector<u64> msg_count_rank(shamcomm::world_size());
        std::vector<u64> comm_size_rank(shamcomm::world_size());

        std::vector<std::tuple<u32, u32, u64>> rank_msg_list;

        bool was_count_limited = false;
        bool was_size_limited = false;
        bool was_msg_size_limited = false;

        // Greedy selection of messages honoring the per-rank message-count and
        // data-volume budgets. Every rank runs the same deterministic loop on
        // the same (seeded-shuffled) list, so the resulting schedule agrees
        // globally without further communication.
        for (auto &[sender_rank, receiver_rank, indices_size] : msg_list) {

            bool msg_count_limit_not_reached = msg_count_rank.at(receiver_rank) < msg_limit
                && msg_count_rank.at(sender_rank) < msg_limit;

            bool recv_size_limit_not_reached = comm_size_rank.at(receiver_rank) < data_count_limit
                && comm_size_rank.at(sender_rank) < data_count_limit;

            was_count_limited = was_count_limited || !msg_count_limit_not_reached;
            was_size_limited = was_size_limited || !recv_size_limit_not_reached;

            bool can_send_recv = msg_count_limit_not_reached && recv_size_limit_not_reached;

            u64 msg_size = std::min(indices_size, max_message_size);
            msg_size = std::min(msg_size, data_count_limit);
            was_msg_size_limited = was_msg_size_limited || (msg_size < indices_size);

            if (can_send_recv) {
                if (sender_rank == shamcomm::world_rank()
                    || receiver_rank == shamcomm::world_rank()) {
                    if (msg_size > 0) {
                        rank_msg_list.push_back(
                            std::make_tuple(sender_rank, receiver_rank, msg_size));
                    }
                }
            }

            msg_count_rank.at(receiver_rank) += 1;
            msg_count_rank.at(sender_rank) += 1;
            comm_size_rank.at(receiver_rank) += msg_size;
            comm_size_rank.at(sender_rank) += msg_size;
        }

        // logger::raw_ln(
        //     shamcomm::world_rank(),
        //     was_count_limited,
        //     was_size_limited,
        //     msg_count_rank,
        //     comm_size_rank);

        // logger::info_ln(
        //     "SPH setup", "rank", shamcomm::world_rank(), "rank_msg_list", rank_msg_list);

        // extract the data
        // NOTE(review): the declaration of `send_data` (distributed shared
        // data of PatchDataLayer, per its use below) was lost in extraction.
        sham::DeviceBuffer idx_to_rem = sham::DeviceBuffer<u32>(0, dev_sched);
        for (auto &[sender_rank, receiver_rank, indices_size] : rank_msg_list) {
            if (sender_rank == shamcomm::world_rank()) {
                std::vector<u32> &idx_to_extract = index_per_ranks[receiver_rank];
                sham::DeviceBuffer _tmp = sham::DeviceBuffer<u32>(idx_to_extract.size(), dev_sched);
                _tmp.copy_from_stdvec(idx_to_extract);

                if (_tmp.get_size() > indices_size) {
                    _tmp.resize(indices_size);
                }

                PatchDataLayer _tmp_pdat = PatchDataLayer(sched.get_layout_ptr_old());
                to_insert.append_subset_to(_tmp, _tmp.get_size(), _tmp_pdat);

                idx_to_rem.append(_tmp);

                send_data.add_obj(sender_rank, receiver_rank, std::move(_tmp_pdat));
            }
        }

        to_insert.remove_ids(idx_to_rem, idx_to_rem.get_size());

        // comm the data to the right ranks
        // NOTE(review): the declaration of `recv_dat` (per its use below) was
        // lost in extraction just above.

        shamalgs::collective::serialize_sparse_comm<PatchDataLayer>(
            dev_sched,
            std::move(send_data),
            recv_dat,
            [&](u64 id) {
                return id; // here the ids in the DDshared are the MPI ranks
            },
            [&](PatchDataLayer &pdat) {
                shamalgs::SerializeHelper ser(dev_sched);
                ser.allocate(pdat.serialize_buf_byte_size());
                pdat.serialize_buf(ser);
                return ser.finalize();
            },
            [&](sham::DeviceBuffer<u8> &&buf) {
                // exchange the buffer held by the distrib data and give it to the
                // serializer
                shamalgs::SerializeHelper ser(dev_sched, std::forward<sham::DeviceBuffer<u8>>(buf));
                return PatchDataLayer::deserialize_buf(ser, sched.get_layout_ptr_old());
            },
            comm_cache);

        // insert the data into the data to be inserted
        recv_dat.for_each([&](u64 sender, u64 receiver, PatchDataLayer &pdat) {
            to_insert.insert_elements(pdat);
        });

        // Reduce the "was limited" flags so every rank prints the same status.
        was_count_limited
            = !shamalgs::collective::are_all_rank_true(!was_count_limited, MPI_COMM_WORLD);
        was_size_limited
            = !shamalgs::collective::are_all_rank_true(!was_size_limited, MPI_COMM_WORLD);
        was_msg_size_limited
            = !shamalgs::collective::are_all_rank_true(!was_msg_size_limited, MPI_COMM_WORLD);
        bool was_sync_limited
            = !shamalgs::collective::are_all_rank_true(!sync_limited, MPI_COMM_WORLD);

        std::string log_suffix = "";
        if (was_count_limited) {
            log_suffix += " (msg count limited)";
        }
        if (was_size_limited) {
            log_suffix += " (total msg size limited)";
        }
        if (was_msg_size_limited) {
            log_suffix += " (msg size limited)";
        }
        if (was_sync_limited) {
            log_suffix += " (sync limited)";
        }
        log_suffix += shambase::format(" (msg count : {})", recv_msg.size());
        log_inject_status(" <- global loop ->" + log_suffix);

        f64 worst_time_get_index_per_ranks
            = shamalgs::collective::allreduce_max<f64>(timer_get_index_per_ranks);

        step_count++;
    }

    if (setup_log) {
        setup_log.value().dump_state();
    }

    shamcomm::mpi::Barrier(MPI_COMM_WORLD);
    time_part_inject.end();
    if (shamcomm::world_rank() == 0) {
        logger::normal_ln(
            "SPH setup", "the injection step took :", time_part_inject.elasped_sec(), "s");
    }

    // NOTE(review): the statement capturing `mem_perf_infos_end` (read below)
    // was lost in extraction just above.

    f64 delta_mpi_timer = shamcomm::mpi::get_timer("total") - mpi_timer_start;
    f64 t_dev_alloc
        = (mem_perf_infos_end.time_alloc_device - mem_perf_infos_start.time_alloc_device)
        + (mem_perf_infos_end.time_free_device - mem_perf_infos_start.time_free_device);
    f64 t_host_alloc = (mem_perf_infos_end.time_alloc_host - mem_perf_infos_start.time_alloc_host)
        + (mem_perf_infos_end.time_free_host - mem_perf_infos_start.time_free_host);

    { // perf infos
        std::vector<f64> time_rank_getter_all_ranks
            = shamalgs::collective::gather(total_time_rank_getter);
        std::vector<f64> max_time_rank_getter_all_ranks
            = shamalgs::collective::gather(max_time_rank_getter);
        std::vector<f64> mpi_timer_all_ranks = shamalgs::collective::gather(delta_mpi_timer);
        std::vector<f64> alloc_time_device_all_ranks = shamalgs::collective::gather(t_dev_alloc);
        std::vector<f64> alloc_time_host_all_ranks = shamalgs::collective::gather(t_host_alloc);
        std::vector<size_t> max_mem_device_all_ranks
            = shamalgs::collective::gather(mem_perf_infos_end.max_allocated_byte_device);
        std::vector<size_t> max_mem_host_all_ranks
            = shamalgs::collective::gather(mem_perf_infos_end.max_allocated_byte_host);

        if (shamcomm::world_rank() == 0) {
            f64 time_part_inject_sec = time_part_inject.elasped_sec();
            f64 sum_t = time_part_inject_sec * shamcomm::world_size();

            f64 sum_time_rank_getter = std::accumulate(
                time_rank_getter_all_ranks.begin(), time_rank_getter_all_ranks.end(), 0.0);
            f64 max_time_rank_getter = *std::max_element(
                max_time_rank_getter_all_ranks.begin(), max_time_rank_getter_all_ranks.end());
            f64 sum_mpi
                = std::accumulate(mpi_timer_all_ranks.begin(), mpi_timer_all_ranks.end(), 0.0);
            f64 sum_alloc_device = std::accumulate(
                alloc_time_device_all_ranks.begin(), alloc_time_device_all_ranks.end(), 0.0);
            f64 sum_alloc_host = std::accumulate(
                alloc_time_host_all_ranks.begin(), alloc_time_host_all_ranks.end(), 0.0);
            size_t sum_mem_device_total = std::accumulate(
                max_mem_device_all_ranks.begin(), max_mem_device_all_ranks.end(), 0_u64);
            size_t sum_mem_host_total = std::accumulate(
                max_mem_host_all_ranks.begin(), max_mem_host_all_ranks.end(), 0_u64);

            static constexpr u32 cols_count = 6;

            using Table = shambase::table;

            Table table(6);

            table.add_double_rule();
            table.add_data(
                {"rank", "rank get (sum/max)", "MPI", "alloc d% h%", "mem (max) d", "mem (max) h"},
                Table::center);
            table.add_double_rule();
            for (u32 i = 0; i < shamcomm::world_size(); i++) {
                table.add_data(
                    {shambase::format("{:<4}", i),
                     shambase::format(
                         "{:.2f}s / {:.2f}s",
                         time_rank_getter_all_ranks[i],
                         max_time_rank_getter_all_ranks[i]),
                     shambase::format("{:.2f}s", mpi_timer_all_ranks[i]),
                     shambase::format(
                         "{:>.1f}% {:<.1f}%",
                         100 * (alloc_time_device_all_ranks[i] / time_part_inject_sec),
                         100 * (alloc_time_host_all_ranks[i] / time_part_inject_sec)),
                     shambase::format("{}", shambase::readable_sizeof(max_mem_device_all_ranks[i])),
                     shambase::format("{}", shambase::readable_sizeof(max_mem_host_all_ranks[i]))},
                    Table::right);
            }
            if (shamcomm::world_size() > 1) {
                table.add_rulled_data({"", "<avg> / <max>", "<avg>", "<avg>", "<sum>", "<sum>"});
                table.add_data(
                    {"all",
                     shambase::format(
                         "{:.2f}s / {:.2f}s",
                         sum_time_rank_getter / shamcomm::world_size(),
                         max_time_rank_getter),
                     shambase::format("{:.2f}s", sum_mpi / shamcomm::world_size()),
                     shambase::format(
                         "{:>.1f}% {:<.1f}%",
                         100 * (sum_alloc_device / sum_t),
                         100 * (sum_alloc_host / sum_t)),
                     shambase::format("{}", shambase::readable_sizeof(sum_mem_device_total)),
                     shambase::format("{}", shambase::readable_sizeof(sum_mem_host_total))},
                    Table::right);
            }
            table.add_rule();
            logger::info_ln("SPH setup", "injection perf report:" + table.render());
        }
    }

    if (part_reordering) {
        modules::ParticleReordering<Tvec, u32, SPHKernel>(context, solver_config, storage)
            .reorder_particles();
    }

    time_setup.end();
    if (shamcomm::world_rank() == 0) {
        logger::normal_ln("SPH setup", "the setup took :", time_setup.elasped_sec(), "s");
    }
}
1028
1029template<class Tvec, template<class> class SPHKernel>
1030inline std::shared_ptr<shammodels::sph::modules::ISPHSetupNode> shammodels::sph::modules::
1031 SPHSetup<Tvec, SPHKernel>::make_modifier_warp_disc(
1032 SetupNodePtr parent, Tscal Rwarp, Tscal Hwarp, Tscal inclination, Tscal posangle) {
1033 return std::shared_ptr<ISPHSetupNode>(new ModifierApplyDiscWarp<Tvec, SPHKernel>(
1034 context, solver_config, parent, Rwarp, Hwarp, inclination, posangle));
1035}
1036
1037template<class Tvec, template<class> class SPHKernel>
1038inline std::shared_ptr<shammodels::sph::modules::ISPHSetupNode> shammodels::sph::modules::
1039 SPHSetup<Tvec, SPHKernel>::make_modifier_custom_warp(
1040 SetupNodePtr parent,
1041 std::function<Tscal(Tscal)> inc_profile,
1042 std::function<Tscal(Tscal)> psi_profile,
1043 std::function<Tvec(Tscal)> k_profile) {
1044 return std::shared_ptr<ISPHSetupNode>(new ModifierApplyCustomWarp<Tvec, SPHKernel>(
1045 context, solver_config, parent, inc_profile, psi_profile, k_profile));
1046}
1047
1048template<class Tvec, template<class> class SPHKernel>
1049inline std::shared_ptr<shammodels::sph::modules::ISPHSetupNode> shammodels::sph::modules::
1050 SPHSetup<Tvec, SPHKernel>::make_modifier_add_offset(
1051 SetupNodePtr parent, Tvec offset_postion, Tvec offset_velocity) {
1052
1053 return std::shared_ptr<ISPHSetupNode>(
1054 new ModifierOffset<Tvec>(context, parent, offset_postion, offset_velocity));
1055}
1056
1057template<class Tvec, template<class> class SPHKernel>
1058inline std::shared_ptr<shammodels::sph::modules::ISPHSetupNode> shammodels::sph::modules::SPHSetup<
1059 Tvec,
1060 SPHKernel>::make_modifier_filter(SetupNodePtr parent, std::function<bool(Tvec)> filter) {
1061
1062 return std::shared_ptr<ISPHSetupNode>(
1063 new ModifierFilter<Tvec, SPHKernel>(context, parent, filter));
1064}
1065
1066template<class Tvec, template<class> class SPHKernel>
1067inline std::shared_ptr<shammodels::sph::modules::ISPHSetupNode> shammodels::sph::modules::
1068 SPHSetup<Tvec, SPHKernel>::make_modifier_split_part(
1069 SetupNodePtr parent, u64 n_split, u64 seed, Tscal h_scaling) {
1070 return std::shared_ptr<ISPHSetupNode>(
1071 new ModifierSplitPart<Tvec>(context, parent, n_split, seed, h_scaling));
1072}
1073
1074using namespace shammath;
1078
constexpr const char * xyz
Position field (3D coordinates)
Header file describing a Node Instance.
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
Collective boolean reduction to check if all ranks have true as input.
bool are_all_rank_true(bool input, MPI_Comm comm)
return true only if all ranks have true as input
The MPI scheduler.
u64 crit_patch_split
splitting limit (if load value > crit_patch_split => patch split)
SchedulerPatchList patch_list
handle the list of the patches of the scheduler
std::vector< shamrock::patch::Patch > local
contain the list of patch owned by the current node
A buffer allocated in USM (Unified Shared Memory)
void copy_from_stdvec(const std::vector< T > &vec)
Copy the content of a std::vector into the buffer.
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
void append(const DeviceBuffer &other)
Append the content of another buffer to this one.
size_t get_size() const
Gets the number of elements in the buffer.
Container for objects shared between two distributed data elements.
void for_each(std::function< void(u64, u64, T &)> &&f)
Apply a function to all stored objects.
iterator add_obj(u64 left_id, u64 right_id, T &&obj)
Add an object associated with a patch pair.
Represents a collection of objects distributed across patches identified by a u64 id.
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
T & get(u64 id)
Returns a reference to an object in the collection.
void reset()
Reset the collection to its initial state.
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:96
void end()
Stops the timer and stores the elapsed time in nanoseconds.
Definition time.hpp:111
f64 elasped_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
Definition time.hpp:123
void start()
Starts the timer.
Definition time.hpp:106
Class to insert data in the PatchScheduler.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
PatchCoordTransform< T > get_patch_transform() const
Get a PatchCoordTransform object that describes the conversion between patch coordinates and domain c...
Definition SimBox.hpp:285
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on vector with size query (size querying variant of vector_allgatherv_ks) //TODO add fault...
Definition exchanges.hpp:98
MemPerfInfos get_mem_perf_info()
Retrieve the memory performance information.
Boolean reduction algorithm for checking if all elements are non-zero.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
std::string readable_sizeof(double size)
given a sizeof value return a readable string Example : readable_sizeof(1024*1024*1024) -> "1....
Definition string.hpp:139
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
i32 world_size()
Gives the size of the MPI communicator.
Definition worldInfo.cpp:38
namespace for math utility
Definition AABB.hpp:26
constexpr u64 u64_max
u64 max value
void err(std::string module_name, Types... var2)
Prints a log message with multiple arguments.
Definition logs.hpp:133
#define __shamrock_stack_entry()
Macro to create a stack entry.
Structure to store the performance informations about memory allocation and deallocation.
f64 time_alloc_host
Time spent allocating memory on the host.
size_t max_allocated_byte_host
max bytes allocated on the host
f64 time_free_device
Time spent deallocating memory on the device.
size_t max_allocated_byte_device
max bytes allocated on the device
f64 time_alloc_device
Time spent allocating memory on the device.
f64 time_free_host
Time spent deallocating memory on the host.
A class that references multiple buffers or similar objects.
Patch object that contain generic patch information.
Definition Patch.hpp:33
Functions related to the MPI communicator.
#define ON_RANK_0(x)
Macro to execute code only on rank 0.
Definition worldInfo.hpp:73
void Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
MPI wrapper for MPI_Exscan.
Definition wrapper.cpp:166
void Barrier(MPI_Comm comm)
MPI wrapper for MPI_Barrier.
Definition wrapper.cpp:194
f64 get_timer(std::string timername)
get a timer value
Definition wrapper.cpp:44
void Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
MPI wrapper for MPI_Allreduce.
Definition wrapper.cpp:119