Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
BasicSPHGhosts.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
17/*
18
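Standalone test harness for the shear-periodic image offsets, meant to be pasted into godbolt (Compiler Explorer). Note: it predates the normalisation by the box length along the shear direction and the widened repetition range used by for_each_patch_shift below.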
20
21
22#include <iostream>
23#include <vector>
24
25namespace sycl{
26 template<class T>
27 struct vec{
28 T _x,_y,_z;
29
30 inline T & x(){
31 return _x;
32 }
33
34 inline T & y(){
35 return _y;
36 }
37 inline T & z(){
38 return _z;
39 }
40 };
41}
42
43
44using i32 = int;
45using i32_3 = sycl::vec<i32>;
46
47template<class T>
48struct ShiftInfo{
49 sycl::vec<T> shift;
50 sycl::vec<T> shift_speed;
51};
52
53template<class T>
54struct ShearPeriodicInfo{
55 i32_3 shear_base;
56 i32_3 shear_dir;
57 T shear_value;
58 T shear_speed;
59};
60
61template<class T>
62inline ShiftInfo<T> compute_shift_infos(
63 i32_3 ioff, ShearPeriodicInfo<T> shear, sycl::vec<T> bsize
64 ){
65
66 i32 dx = ioff.x()*shear.shear_base.x();
67 i32 dy = ioff.y()*shear.shear_base.y();
68 i32 dz = ioff.z()*shear.shear_base.z();
69
70 i32 d = dx + dy + dz;
71
72 sycl::vec<T> shift = {
73 (d*shear.shear_dir.x())*shear.shear_value + bsize.x()*ioff.x(),
74 (d*shear.shear_dir.y())*shear.shear_value + bsize.y()*ioff.y() ,
75 (d*shear.shear_dir.z())*shear.shear_value + bsize.z()*ioff.z()
76 };
77 sycl::vec<T> shift_speed = {
78 (d*shear.shear_dir.x())*shear.shear_speed,
79 (d*shear.shear_dir.y())*shear.shear_speed,
80 (d*shear.shear_dir.z())*shear.shear_speed
81 };
82
83 return {shift,shift_speed};
84}
85
86template<class T>
87inline void for_each_patch_shift(ShearPeriodicInfo<T> shearinfo, sycl::vec<T> bsize){
88
89 i32_3 loop_offset = {0,0,0};
90
91 std::vector<i32_3> list_possible;
92
93
94 i32 repetition_x = 1;
95 i32 repetition_y = 1;
96 i32 repetition_z = 1;
97
98
99
100 for (i32 xoff = -repetition_x; xoff <= repetition_x; xoff++) {
101 for (i32 yoff = -repetition_y; yoff <= repetition_y; yoff++) {
102 for (i32 zoff = -repetition_z; zoff <= repetition_z; zoff++) {
103
104
105 i32 dx = xoff*shearinfo.shear_base.x();
106 i32 dy = yoff*shearinfo.shear_base.y();
107 i32 dz = zoff*shearinfo.shear_base.z();
108
109 i32 d = dx + dy + dz;
110
111 i32 df = -int(d * shearinfo.shear_value);
112
113 i32_3 off_d = {
114 shearinfo.shear_dir.x()*df,
115 shearinfo.shear_dir.y()*df,
116 shearinfo.shear_dir.z()*df
117 };
118
119 list_possible.push_back({xoff+off_d.x(),yoff+off_d.y(),zoff+off_d.z()});
120 }
121 }
122 }
123
124 for(i32_3 off : list_possible){
125
126 auto shift = compute_shift_infos(off,shearinfo,bsize);
127
128 std::cout <<
129 off.x() << " " << off.y() << " " << off.z() << " | " <<
130 shift.shift.x() << " " << shift.shift.y() << " " << shift.shift.z() << " "<<std::endl;
131 }
132
133
134
135}
136
137
138int main(){
139
140 ShearPeriodicInfo<float> shear{
141 {1,0,0},
142 {0,0,1},
143 13.5,
144 1
145 };
146
147 for_each_patch_shift(shear, {1,1,1});
148
149}
150
151
152*/
153
154#include "shambase/exception.hpp"
155#include "shambase/time.hpp"
158#include "shamcomm/worldInfo.hpp"
160#include <functional>
161#include <vector>
162
/**
 * @brief Offsets applied to one periodic image of a patch.
 *
 * Produced by compute_shift_infos; `shift` is added to the sender's coordinates
 * and `shift_speed` is forwarded alongside it when building interfaces.
 *
 * @tparam T scalar type of the vector components
 */
template<class T>
struct ShiftInfo {
    sycl::vec<T, 3> shift;       ///< positional offset of the image
    sycl::vec<T, 3> shift_speed; ///< speed offset of the image (derived from the shear speed)
};
168
169template<class T>
170using ShearPeriodicInfo =
172
173template<class T>
174inline ShiftInfo<T> compute_shift_infos(
175 i32_3 ioff, ShearPeriodicInfo<T> shear, sycl::vec<T, 3> bsize) {
176
177 i32 dx = ioff.x() * shear.shear_base.x();
178 i32 dy = ioff.y() * shear.shear_base.y();
179 i32 dz = ioff.z() * shear.shear_base.z();
180
181 i32 d = dx + dy + dz;
182
183 sycl::vec<T, 3> shift
184 = {(d * shear.shear_dir.x()) * shear.shear_value + bsize.x() * ioff.x(),
185 (d * shear.shear_dir.y()) * shear.shear_value + bsize.y() * ioff.y(),
186 (d * shear.shear_dir.z()) * shear.shear_value + bsize.z() * ioff.z()};
187 sycl::vec<T, 3> shift_speed
188 = {(d * shear.shear_dir.x()) * shear.shear_speed,
189 (d * shear.shear_dir.y()) * shear.shear_speed,
190 (d * shear.shear_dir.z()) * shear.shear_speed};
191
192 return {shift, shift_speed};
193}
194
195template<class T>
196inline void for_each_patch_shift(
197 ShearPeriodicInfo<T> shearinfo,
198 sycl::vec<T, 3> bsize,
199 std::function<void(i32_3, ShiftInfo<T>)> funct) {
200
201 i32_3 loop_offset = {0, 0, 0};
202
203 std::vector<i32_3> list_possible;
204
205 // logger::raw_ln("testing :",shearinfo.shear_value,shearinfo.shear_dir, shearinfo.shear_base);
206
207 // a bit of dirty fix doesn't hurt
208 // this should be done in a better way a some point
209 i32 repetition_x = 1 + abs(shearinfo.shear_dir.x());
210 i32 repetition_y = 1 + abs(shearinfo.shear_dir.y());
211 i32 repetition_z = 1 + abs(shearinfo.shear_dir.z());
212
213 T sz = bsize.x() * shearinfo.shear_dir.x() + bsize.y() * shearinfo.shear_dir.y()
214 + bsize.z() * shearinfo.shear_dir.z();
215
216 for (i32 xoff = -repetition_x; xoff <= repetition_x; xoff++) {
217 for (i32 yoff = -repetition_y; yoff <= repetition_y; yoff++) {
218 for (i32 zoff = -repetition_z; zoff <= repetition_z; zoff++) {
219
220 i32 dx = xoff * shearinfo.shear_base.x();
221 i32 dy = yoff * shearinfo.shear_base.y();
222 i32 dz = zoff * shearinfo.shear_base.z();
223
224 i32 d = dx + dy + dz;
225
226 i32 df = -int(d * shearinfo.shear_value / sz);
227
228 i32_3 off_d
229 = {shearinfo.shear_dir.x() * df,
230 shearinfo.shear_dir.y() * df,
231 shearinfo.shear_dir.z() * df};
232
233 // on redhat based systems stl vector freaks out
234 // because iterator to back does *(end() - 1)
235 // the issue is that the compiler gets confused
236 // by the sycl::vec defining the - operator
237 // creating the ambiguity and ...
238 // ultimatly the compiler shitting itself
239 list_possible.resize(list_possible.size() + 1);
240 list_possible[list_possible.size() - 1]
241 = i32_3{xoff + off_d.x(), yoff + off_d.y(), zoff + off_d.z()};
242 }
243 }
244 }
245
246 // logger::raw_ln("trying", list_possible.size(), "patches ghosts");
247
248 for (i32_3 off : list_possible) {
249
250 auto shift = compute_shift_infos(off, shearinfo, bsize);
251
252 // logger::raw_ln("check :",off,shift.shift, shift.shift_speed);
253
254 funct(off, shift);
255 }
256}
257
258using namespace shammodels::sph;
259
260template<class vec>
262 SerialPatchTree<vec> &sptree,
263 shamrock::patch::PatchtreeField<flt> &int_range_max_tree,
265
266 StackEntry stack_loc{};
267
268 using namespace shamrock::patch;
269 using namespace shammath;
270
271 i32 repetition_x = 1;
272 i32 repetition_y = 1;
273 i32 repetition_z = 1;
274
275 shamrock::patch::SimulationBoxInfo &sim_box = sched.get_sim_box();
276
277 PatchCoordTransform<vec> patch_coord_transf = sim_box.get_patch_transform<vec>();
278 vec bsize = sim_box.get_bounding_box_size<vec>();
279
280 GeneratorMap interf_map;
281
283 using BCConfig = typename CfgClass::Variant;
284
285 using BCFree = typename CfgClass::Free;
286 using BCPeriodic = typename CfgClass::Periodic;
287 using BCShearingPeriodic = typename CfgClass::ShearingPeriodic;
288
289 shambase::Timer base_timer;
290 base_timer.start();
291
292 if (BCPeriodic *cfg = std::get_if<BCPeriodic>(&ghost_config)) {
293 sycl::host_accessor acc_tf{
294 shambase::get_check_ref(int_range_max_tree.internal_buf), sycl::read_only};
295
296 for (i32 xoff = -repetition_x; xoff <= repetition_x; xoff++) {
297 for (i32 yoff = -repetition_y; yoff <= repetition_y; yoff++) {
298 for (i32 zoff = -repetition_z; zoff <= repetition_z; zoff++) {
299
300 // sender translation
301 vec periodic_offset = vec{xoff * bsize.x(), yoff * bsize.y(), zoff * bsize.z()};
302
303 sycl::host_accessor tree{
304 shambase::get_check_ref(sptree.serial_tree_buf), sycl::read_only};
305 sycl::host_accessor lpid{
306 shambase::get_check_ref(sptree.linked_patch_ids_buf), sycl::read_only};
307
308#pragma omp parallel for
309 for (u32 i = 0; i < sched.patch_list.local.size(); i++) {
310 const shamrock::patch::Patch &psender = sched.patch_list.local[i];
311 if (!psender.is_err_mode()) {
312 CoordRange<vec> sender_bsize = patch_coord_transf.to_obj_coord(psender);
313 CoordRange<vec> sender_bsize_off
314 = sender_bsize.add_offset(periodic_offset);
315
316 flt sender_volume = sender_bsize.get_volume();
317
318 flt sender_h_max = int_range_max.get(psender.id_patch);
319
320 using PtNode = typename SerialPatchTree<vec>::PtNode;
321
322 sptree.host_for_each_leafs_internal(
323 [&](u64 tree_id, PtNode n) {
324 flt receiv_h_max = acc_tf[tree_id];
325 CoordRange<vec> receiv_exp{
326 n.box_min - receiv_h_max, n.box_max + receiv_h_max};
327
328 return receiv_exp.get_intersect(sender_bsize_off)
329 .is_not_empty();
330 },
331 [&](u64 id_found, PtNode n) {
332 if ((id_found == psender.id_patch) && (xoff == 0) && (yoff == 0)
333 && (zoff == 0)) {
334 return;
335 }
336
337 CoordRange<vec> receiv_exp
338 = CoordRange<vec>{n.box_min, n.box_max}.expand_all(
339 int_range_max.get(id_found));
340
341 CoordRange<vec> interf_volume = sender_bsize.get_intersect(
342 receiv_exp.add_offset(-periodic_offset));
343
344#pragma omp critical
345 interf_map.add_obj(
346 psender.id_patch,
347 id_found,
348 {periodic_offset,
349 {0, 0, 0},
350 {xoff, yoff, zoff},
351 interf_volume,
352 interf_volume.get_volume() / sender_volume});
353 },
354 tree,
355 lpid);
356 }
357 }
358 }
359 }
360 }
361 } else if (BCShearingPeriodic *cfg = std::get_if<BCShearingPeriodic>(&ghost_config)) {
362 sycl::host_accessor acc_tf{
363 shambase::get_check_ref(int_range_max_tree.internal_buf), sycl::read_only};
364
365 for_each_patch_shift<flt>(*cfg, bsize, [&](i32_3 ioff, ShiftInfo<flt> shift) {
366 i32 xoff = ioff.x();
367 i32 yoff = ioff.y();
368 i32 zoff = ioff.z();
369
370 vec offset = shift.shift;
371
372 sycl::host_accessor tree{
373 shambase::get_check_ref(sptree.serial_tree_buf), sycl::read_only};
374 sycl::host_accessor lpid{
375 shambase::get_check_ref(sptree.linked_patch_ids_buf), sycl::read_only};
376
377#pragma omp parallel for
378 for (u32 i = 0; i < sched.patch_list.local.size(); i++) {
379 const shamrock::patch::Patch &psender = sched.patch_list.local[i];
380 if (!psender.is_err_mode()) {
381
382 CoordRange<vec> sender_bsize = patch_coord_transf.to_obj_coord(psender);
383 CoordRange<vec> sender_bsize_off = sender_bsize.add_offset(offset);
384
385 flt sender_volume = sender_bsize.get_volume();
386
387 flt sender_h_max = int_range_max.get(psender.id_patch);
388
389 using PtNode = typename SerialPatchTree<vec>::PtNode;
390
391 sptree.host_for_each_leafs_internal(
392 [&](u64 tree_id, PtNode n) {
393 flt receiv_h_max = acc_tf[tree_id];
394 CoordRange<vec> receiv_exp{
395 n.box_min - receiv_h_max, n.box_max + receiv_h_max};
396
397 return receiv_exp.get_intersect(sender_bsize_off).is_not_empty();
398 },
399 [&](u64 id_found, PtNode n) {
400 if ((id_found == psender.id_patch) && (xoff == 0) && (yoff == 0)
401 && (zoff == 0)) {
402 return;
403 }
404
405 CoordRange<vec> receiv_exp
406 = CoordRange<vec>{n.box_min, n.box_max}.expand_all(
407 int_range_max.get(id_found));
408
409 CoordRange<vec> interf_volume
410 = sender_bsize.get_intersect(receiv_exp.add_offset(-offset));
411
412#pragma omp critical
413 interf_map.add_obj(
414 psender.id_patch,
415 id_found,
416 {offset,
417 shift.shift_speed,
418 {xoff, yoff, zoff},
419 interf_volume,
420 interf_volume.get_volume() / sender_volume});
421
422 // logger::raw_ln("found :",offset, shift.shift_speed, vec{xoff, yoff,
423 // zoff});
424 },
425 tree,
426 lpid);
427 }
428 }
429 });
430
431 } else {
432 sycl::host_accessor acc_tf{
433 shambase::get_check_ref(int_range_max_tree.internal_buf), sycl::read_only};
434 // sender translation
435 vec periodic_offset = vec{0, 0, 0};
436
437 sycl::host_accessor tree{shambase::get_check_ref(sptree.serial_tree_buf), sycl::read_only};
438 sycl::host_accessor lpid{
439 shambase::get_check_ref(sptree.linked_patch_ids_buf), sycl::read_only};
440
441#pragma omp parallel for
442 for (u32 i = 0; i < sched.patch_list.local.size(); i++) {
443 const shamrock::patch::Patch &psender = sched.patch_list.local[i];
444 if (!psender.is_err_mode()) {
445 CoordRange<vec> sender_bsize = patch_coord_transf.to_obj_coord(psender);
446 CoordRange<vec> sender_bsize_off = sender_bsize.add_offset(periodic_offset);
447
448 flt sender_volume = sender_bsize.get_volume();
449
450 flt sender_h_max = int_range_max.get(psender.id_patch);
451
452 using PtNode = typename SerialPatchTree<vec>::PtNode;
453
454 sptree.host_for_each_leafs_internal(
455 [&](u64 tree_id, PtNode n) {
456 flt receiv_h_max = acc_tf[tree_id];
457 CoordRange<vec> receiv_exp{
458 n.box_min - receiv_h_max, n.box_max + receiv_h_max};
459
460 return receiv_exp.get_intersect(sender_bsize_off).is_not_empty();
461 },
462 [&](u64 id_found, PtNode n) {
463 if (id_found == psender.id_patch) {
464 return;
465 }
466
467 CoordRange<vec> receiv_exp
468 = CoordRange<vec>{n.box_min, n.box_max}.expand_all(
469 int_range_max.get(id_found));
470
471 CoordRange<vec> interf_volume
472 = sender_bsize.get_intersect(receiv_exp.add_offset(-periodic_offset));
473
474#pragma omp critical
475 interf_map.add_obj(
476 psender.id_patch,
477 id_found,
478 {periodic_offset,
479 {0, 0, 0},
480 {0, 0, 0},
481 interf_volume,
482 interf_volume.get_volume() / sender_volume});
483 },
484 tree,
485 lpid);
486 }
487 }
488 }
489
490 base_timer.end();
491
492 // f64 worse_time = shamalgs::collective::allreduce_max(base_timer.elasped_sec());
493 // if (shamcomm::world_rank() == 0) {
494 // shamlog_info_ln(
495 // "BasicSPHGhosts",
496 // "find_interfaces time:",
497 // base_timer.get_time_str(),
498 // "worse time:",
499 // worse_time);
500 // }
501
502 // interf_map.for_each([](u64 sender, u64 receiver, InterfaceBuildInfos build){
503 // logger::raw_ln("found interface
504 // :",sender,"->",receiver,"ratio:",build.volume_ratio,
505 // "volume:",build.cut_volume.lower,build.cut_volume.upper);
506 // });
507
508 return interf_map;
509}
510
511template<class vec>
514 StackEntry stack_loc{};
515 using namespace shamrock::patch;
516
518
519 std::map<u64, f64> send_count_stats;
520
521 gen.for_each([&](u64 sender, u64 receiver, InterfaceBuildInfos &build) {
522 shamrock::patch::PatchDataLayer &src = sched.patch_data.get_pdat(sender);
523 PatchDataField<vec> &xyz = src.get_field<vec>(0);
524
525 sham::DeviceBuffer<u32> idxs_res = xyz.get_ids_where(
526 [](auto access, u32 id, vec vmin, vec vmax) {
527 return Patch::is_in_patch_converted(access[id], vmin, vmax);
528 },
529 build.cut_volume.lower,
530 build.cut_volume.upper);
531
532 u32 pcnt = idxs_res.get_size();
533
534 // prevent sending empty patches
535 if (pcnt == 0) {
536 return;
537 }
538
539 f64 ratio = f64(pcnt) / f64(src.get_obj_cnt());
540
541 shamlog_debug_ln(
542 "InterfaceGen",
543 "gen interface :",
544 sender,
545 "->",
546 receiver,
547 "volume ratio:",
548 build.volume_ratio,
549 "part_ratio:",
550 ratio);
551
552 res.add_obj(sender, receiver, InterfaceIdTable{build, std::move(idxs_res), ratio});
553
554 send_count_stats[sender] += ratio;
555 });
556
557 bool has_warn = false;
558
559 std::string warn_log = "";
560
561 for (auto &[k, v] : send_count_stats) {
562 if (v > 0.2) {
563 warn_log += shambase::format("\n patch {} high interf/patch volume: {}", k, v);
564 has_warn = true;
565 }
566 }
567
568 if (has_warn && shamcomm::world_rank() == 0) {
569 warn_log = "\n This can lead to high mpi "
570 "overhead, try to increase the patch split crit"
571 + warn_log;
572 }
573
574 if (has_warn) {
575 logger::warn_ln("InterfaceGen", "High interface/patch volume ratio." + warn_log);
576 }
577
578 return res;
579}
580
581template<class vec>
584 StackEntry stack_loc{};
585
586 static u32 cnt_dump_debug = 0;
587
588 std::string loc_graph = "";
589 interf_info.for_each([&loc_graph](u64 send, u64 recv, InterfaceIdTable &info) {
590 loc_graph += shambase::format(" p{} -> p{}\n", send, recv);
591 });
592
593 sched.for_each_patch_data(
595 if (pdat.get_obj_cnt() > 0) {
596 loc_graph += shambase::format(
597 " p{} [label= \"id={} N={}\"]\n", id, id, pdat.get_obj_cnt());
598 }
599 });
600
601 std::string dot_graph = "";
602 shamalgs::collective::gather_str(loc_graph, dot_graph);
603
604 dot_graph = "strict digraph {\n" + dot_graph + "}";
605
606 if (shamcomm::world_rank() == 0) {
607 std::string fname = shambase::format("ghost_graph_{}.dot", cnt_dump_debug);
608 logger::info_ln("SPH Ghost", "writing", fname);
609 shambase::write_string_to_file(fname, dot_graph);
610 cnt_dump_debug++;
611 }
612}
613
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
A buffer allocated in USM (Unified Shared Memory)
void for_each(std::function< void(u64, u64, T &)> &&f)
Apply a function to all stored objects.
iterator add_obj(u64 left_id, u64 right_id, T &&obj)
Add an object associated with a patch pair.
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:96
void start()
Starts the timer.
Definition time.hpp:106
Vector class based on std::array storage and mdspan.
Definition matrix.hpp:96
shambase::DistributedDataShared< InterfaceIdTable > gen_id_table_interfaces(GeneratorMap &&gen)
precompute interfaces members and cache result in the return
GeneratorMap find_interfaces(SerialPatchTree< vec > &sptree, shamrock::patch::PatchtreeField< flt > &int_range_max_tree, shamrock::patch::PatchField< flt > &int_range_max)
Find interfaces and their metadata.
PatchDataLayer container class, the layout is described in patchdata_layout.
Store the information related to the size of the simulation box to convert patch integer coordinates ...
Definition SimBox.hpp:35
T get_bounding_box_size() const
Get the size of the stored bounding box of the domain.
Definition SimBox.hpp:87
PatchCoordTransform< T > get_patch_transform() const
Get a PatchCoordTransform object that describes the conversion between patch coordinates and domain c...
Definition SimBox.hpp:285
This header file contains utility functions related to exception handling in the code.
MPI string gather / allgather helpers (declarations; implementations in shamalgs/src/collective/gathe...
void gather_str(const std::string &send_vec, std::string &recv_vec)
Gathers a string from all nodes and store the result in a std::string.
void write_string_to_file(std::string filename, std::string s)
dump a string to a file
Definition string.hpp:168
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for math utility
Definition AABB.hpp:26
namespace for the sph model
Boundary conditions configuration.
Definition BCConfig.hpp:40
Patch object that contain generic patch information.
Definition Patch.hpp:33
bool is_err_mode() const
check if a patch is in error mode
Definition Patch.hpp:119
u64 id_patch
unique key that identify the patch
Definition Patch.hpp:86
Functions related to the MPI communicator.