Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
Solver.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
16
17#include "shambase/assert.hpp"
20#include "shambase/memory.hpp"
22#include "shambase/string.hpp"
23#include "shambase/tabulate.hpp"
24#include "shambase/time.hpp"
32#include "shambackends/math.hpp"
33#include "shamcomm/logs.hpp"
35#include "shamcomm/wrapper.hpp"
76#include "shamphys/mhd.hpp"
103#include "shamsys/legacy/log.hpp"
107#include <memory>
108#include <stdexcept>
109#include <vector>
110
111template<class Tvec, template<class> class Kern>
113
114 shamrock::patch::PatchDataLayerLayout &pdl = scheduler().pdl_old();
115 bool has_B_field = solver_config.has_field_B_on_rho();
116 bool has_psi_field = solver_config.has_field_psi_on_ch();
117 bool has_epsilon_field = solver_config.dust_config.has_epsilon_field();
118 bool has_deltav_field = solver_config.dust_config.has_deltav_field();
119 bool has_s_j_field = solver_config.dust_config.has_s_j_field();
120
121 using namespace shamrock::solvergraph;
122
123 SolverGraph &solver_graph = storage.solver_graph;
124
125 solver_graph.register_edge(
126 "scheduler_patchdata", PatchDataLayerRefs("patchdatas", "\\mathbb{U}_{\\rm patch}"));
127 solver_graph.register_edge("part_counts", Indexes<u32>("Npart_patch", "N_{\\rm part}_p"));
128
129 solver_graph.register_edge("dt", IDataEdge<Tscal>("dt", "dt"));
130 solver_graph.register_edge("dt_half", IDataEdge<Tscal>("dt_half", "\\frac{dt}{2}"));
131 solver_graph.register_edge("gpart_mass", ScalarEdge<Tscal>("m", "m"));
132
133 solver_graph.register_edge("xyz", FieldRefs<Tvec>("xyz", "\\mathbf{r}"));
134 solver_graph.register_edge("vxyz", FieldRefs<Tvec>("vxyz", "\\mathbf{v}"));
135 solver_graph.register_edge("axyz", FieldRefs<Tvec>("axyz", "\\mathbf{a}"));
136 solver_graph.register_edge("uint", FieldRefs<Tscal>("uint", "u_{\\rm int}"));
137 solver_graph.register_edge("duint", FieldRefs<Tscal>("duint", "du_{\\rm int}"));
138 solver_graph.register_edge("hpart", FieldRefs<Tscal>("hpart", "h_{\\rm part}"));
139
140 if (has_B_field) {
141 solver_graph.register_edge("B/rho", FieldRefs<Tvec>("B/rho", "B_{\\rho}"));
142 solver_graph.register_edge("dB/rho", FieldRefs<Tvec>("dB/rho", "dB_{\\rho}"));
143 }
144 if (has_psi_field) {
145 solver_graph.register_edge("psi/ch", FieldRefs<Tscal>("psi/ch", "\\psi_{\\rm ch}"));
146 solver_graph.register_edge("dpsi/ch", FieldRefs<Tscal>("dpsi/ch", "d\\psi_{\\rm ch}"));
147 }
148 if (has_epsilon_field) {
149 solver_graph.register_edge("epsilon", FieldRefs<Tscal>("epsilon", "\\epsilon"));
150 solver_graph.register_edge("dtepsilon", FieldRefs<Tscal>("dtepsilon", "d\\epsilon"));
151 }
152 if (has_deltav_field) {
153 solver_graph.register_edge("deltav", FieldRefs<Tvec>("deltav", "\\Delta v"));
154 solver_graph.register_edge("dtdeltav", FieldRefs<Tvec>("dtdeltav", "d\\Delta v"));
155 }
156 if (has_s_j_field) {
157 solver_graph.register_edge("s_j", FieldRefs<Tscal>("s_j", "S_j"));
158 solver_graph.register_edge("ds_j_dt", FieldRefs<Tscal>("ds_j_dt", "dS_j/dt"));
159 }
160
161 {
162 auto set_gpart_mass = solver_graph.register_node(
163 "set_gpart_mass", NodeSetEdge<ScalarEdge<Tscal>>([&](ScalarEdge<Tscal> &gpart_mass) {
164 gpart_mass.value = solver_config.gpart_mass;
165 }));
166 shambase::get_check_ref(set_gpart_mass)
167 .set_edges(solver_graph.get_edge_ptr<ScalarEdge<Tscal>>("gpart_mass"));
168 }
169
171 // attach fields to scheduler
173 {
174 std::vector<std::shared_ptr<shamrock::solvergraph::INode>> attach_field_sequence;
175
176 {
177 auto set_scheduler_patchdata = solver_graph.register_node(
178 "set_scheduler_patchdata",
179 NodeSetEdge<PatchDataLayerRefs>([&](PatchDataLayerRefs &scheduler_patchdata) {
180 scheduler_patchdata.free_alloc();
181 scheduler().for_each_patchdata_nonempty(
182 [&](const shamrock::patch::Patch &p,
184 scheduler_patchdata.patchdatas.add_obj(p.id_patch, std::ref(pdat));
185 });
186 }));
187 shambase::get_check_ref(set_scheduler_patchdata)
188 .set_edges(solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"));
189 attach_field_sequence.push_back(set_scheduler_patchdata);
190 }
191
192 {
193 auto attach_part_counts
194 = solver_graph.register_node("attach_part_counts", GetObjCntFromLayer{});
195 shambase::get_check_ref(attach_part_counts)
196 .set_edges(
197 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
198 solver_graph.get_edge_ptr<Indexes<u32>>("part_counts"));
199 attach_field_sequence.push_back(attach_part_counts);
200 }
201
202 {
203 auto attach_xyz
204 = solver_graph.register_node("attach_xyz", GetFieldRefFromLayer<Tvec>(pdl, "xyz"));
205 shambase::get_check_ref(attach_xyz)
206 .set_edges(
207 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
208 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("xyz"));
209 attach_field_sequence.push_back(attach_xyz);
210 }
211
212 {
213 auto attach_vxyz = solver_graph.register_node(
214 "attach_vxyz", GetFieldRefFromLayer<Tvec>(pdl, "vxyz"));
215 shambase::get_check_ref(attach_vxyz)
216 .set_edges(
217 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
218 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("vxyz"));
219 attach_field_sequence.push_back(attach_vxyz);
220 }
221
222 {
223 auto attach_axyz = solver_graph.register_node(
224 "attach_axyz", GetFieldRefFromLayer<Tvec>(pdl, "axyz"));
225 shambase::get_check_ref(attach_axyz)
226 .set_edges(
227 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
228 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("axyz"));
229 attach_field_sequence.push_back(attach_axyz);
230 }
231
232 {
233 auto attach_uint = solver_graph.register_node(
234 "attach_uint", GetFieldRefFromLayer<Tscal>(pdl, "uint"));
235 shambase::get_check_ref(attach_uint)
236 .set_edges(
237 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
238 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("uint"));
239 attach_field_sequence.push_back(attach_uint);
240 }
241
242 {
243 auto attach_duint = solver_graph.register_node(
244 "attach_duint", GetFieldRefFromLayer<Tscal>(pdl, "duint"));
245 shambase::get_check_ref(attach_duint)
246 .set_edges(
247 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
248 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("duint"));
249 attach_field_sequence.push_back(attach_duint);
250 }
251
252 {
253 auto attach_hpart = solver_graph.register_node(
254 "attach_hpart", GetFieldRefFromLayer<Tscal>(pdl, "hpart"));
255 shambase::get_check_ref(attach_hpart)
256 .set_edges(
257 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
258 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("hpart"));
259 attach_field_sequence.push_back(attach_hpart);
260 }
261
262 if (has_B_field) {
263 auto attach_B_on_rho = solver_graph.register_node(
264 "attach_B_on_rho", GetFieldRefFromLayer<Tvec>(pdl, "B/rho"));
265 shambase::get_check_ref(attach_B_on_rho)
266 .set_edges(
267 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
268 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("B/rho"));
269 attach_field_sequence.push_back(attach_B_on_rho);
270 }
271
272 if (has_B_field) {
273 auto attach_dB_on_rho = solver_graph.register_node(
274 "attach_dB_on_rho", GetFieldRefFromLayer<Tvec>(pdl, "dB/rho"));
275 shambase::get_check_ref(attach_dB_on_rho)
276 .set_edges(
277 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
278 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("dB/rho"));
279 attach_field_sequence.push_back(attach_dB_on_rho);
280 }
281
282 if (has_psi_field) {
283 auto attach_psi_on_ch = solver_graph.register_node(
284 "attach_psi_on_ch", GetFieldRefFromLayer<Tscal>(pdl, "psi/ch"));
285 shambase::get_check_ref(attach_psi_on_ch)
286 .set_edges(
287 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
288 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("psi/ch"));
289 attach_field_sequence.push_back(attach_psi_on_ch);
290 }
291
292 if (has_psi_field) {
293 auto attach_dpsi_on_ch = solver_graph.register_node(
294 "attach_dpsi_on_ch", GetFieldRefFromLayer<Tscal>(pdl, "dpsi/ch"));
295 shambase::get_check_ref(attach_dpsi_on_ch)
296 .set_edges(
297 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
298 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("dpsi/ch"));
299 attach_field_sequence.push_back(attach_dpsi_on_ch);
300 }
301
302 if (has_epsilon_field) {
303 auto attach_epsilon = solver_graph.register_node(
304 "attach_epsilon", GetFieldRefFromLayer<Tscal>(pdl, "epsilon"));
305 shambase::get_check_ref(attach_epsilon)
306 .set_edges(
307 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
308 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("epsilon"));
309 attach_field_sequence.push_back(attach_epsilon);
310 }
311
312 if (has_epsilon_field) {
313 auto attach_dtepsilon = solver_graph.register_node(
314 "attach_dtepsilon", GetFieldRefFromLayer<Tscal>(pdl, "dtepsilon"));
315 shambase::get_check_ref(attach_dtepsilon)
316 .set_edges(
317 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
318 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("dtepsilon"));
319 attach_field_sequence.push_back(attach_dtepsilon);
320 }
321
322 if (has_deltav_field) {
323 auto attach_deltav = solver_graph.register_node(
324 "attach_deltav", GetFieldRefFromLayer<Tvec>(pdl, "deltav"));
325 shambase::get_check_ref(attach_deltav)
326 .set_edges(
327 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
328 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("deltav"));
329 attach_field_sequence.push_back(attach_deltav);
330 }
331
332 if (has_deltav_field) {
333 auto attach_dtdeltav = solver_graph.register_node(
334 "attach_dtdeltav", GetFieldRefFromLayer<Tvec>(pdl, "dtdeltav"));
335 shambase::get_check_ref(attach_dtdeltav)
336 .set_edges(
337 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
338 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("dtdeltav"));
339 attach_field_sequence.push_back(attach_dtdeltav);
340 }
341
342 if (has_s_j_field) {
343 auto attach_s_j
344 = solver_graph.register_node("attach_s_j", GetFieldRefFromLayer<Tscal>(pdl, "s_j"));
345 shambase::get_check_ref(attach_s_j)
346 .set_edges(
347 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
348 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("s_j"));
349 attach_field_sequence.push_back(attach_s_j);
350 }
351
352 if (has_s_j_field) {
353 auto attach_ds_j_dt = solver_graph.register_node(
354 "attach_ds_j_dt", GetFieldRefFromLayer<Tscal>(pdl, "ds_j_dt"));
355 shambase::get_check_ref(attach_ds_j_dt)
356 .set_edges(
357 solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata"),
358 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("ds_j_dt"));
359 attach_field_sequence.push_back(attach_ds_j_dt);
360 }
361 solver_graph.register_node(
362 "attach fields to scheduler",
363 OperationSequence("attach fields", std::move(attach_field_sequence)));
364 }
365
367 // leapfrog predictor
369
370 {
371
372 auto make_half_step_sequence = [&](std::string prefix) {
373 std::vector<std::shared_ptr<shamrock::solvergraph::INode>> half_step_sequence;
374
375 {
376 auto half_step_vxyz = solver_graph.register_node(
378 shambase::get_check_ref(half_step_vxyz)
379 .set_edges(
380 solver_graph.get_edge_ptr<IDataEdge<Tscal>>("dt_half"),
381 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("axyz"),
382 solver_graph.get_edge_ptr<Indexes<u32>>("part_counts"),
383 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("vxyz"));
384 half_step_sequence.push_back(half_step_vxyz);
385 }
386
387 {
388 auto half_step_uint = solver_graph.register_node(
390 shambase::get_check_ref(half_step_uint)
391 .set_edges(
392 solver_graph.get_edge_ptr<IDataEdge<Tscal>>("dt_half"),
393 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("duint"),
394 solver_graph.get_edge_ptr<Indexes<u32>>("part_counts"),
395 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("uint"));
396 half_step_sequence.push_back(half_step_uint);
397 }
398
399 if (has_B_field) {
400 auto half_step_B_on_rho = solver_graph.register_node(
402 shambase::get_check_ref(half_step_B_on_rho)
403 .set_edges(
404 solver_graph.get_edge_ptr<IDataEdge<Tscal>>("dt_half"),
405 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("dB/rho"),
406 solver_graph.get_edge_ptr<Indexes<u32>>("part_counts"),
407 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("B/rho"));
408 half_step_sequence.push_back(half_step_B_on_rho);
409 }
410
411 if (has_psi_field) {
412 auto half_step_psi_on_ch = solver_graph.register_node(
413 prefix + "_psi_on_ch", shammodels::common::modules::ForwardEuler<Tscal>{});
414 shambase::get_check_ref(half_step_psi_on_ch)
415 .set_edges(
416 solver_graph.get_edge_ptr<IDataEdge<Tscal>>("dt_half"),
417 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("dpsi/ch"),
418 solver_graph.get_edge_ptr<Indexes<u32>>("part_counts"),
419 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("psi/ch"));
420 half_step_sequence.push_back(half_step_psi_on_ch);
421 }
422
423 if (has_epsilon_field) {
424 auto half_step_epsilon = solver_graph.register_node(
426 shambase::get_check_ref(half_step_epsilon)
427 .set_edges(
428 solver_graph.get_edge_ptr<IDataEdge<Tscal>>("dt_half"),
429 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("dtepsilon"),
430 solver_graph.get_edge_ptr<Indexes<u32>>("part_counts"),
431 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("epsilon"));
432 half_step_sequence.push_back(half_step_epsilon);
433 }
434
435 if (has_deltav_field) {
436 auto half_step_deltav = solver_graph.register_node(
438 shambase::get_check_ref(half_step_deltav)
439 .set_edges(
440 solver_graph.get_edge_ptr<IDataEdge<Tscal>>("dt_half"),
441 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("dtdeltav"),
442 solver_graph.get_edge_ptr<Indexes<u32>>("part_counts"),
443 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("deltav"));
444 half_step_sequence.push_back(half_step_deltav);
445 }
446
447 if (has_s_j_field) {
448 u32 ndust = solver_config.dust_config.get_dust_nvar();
449 auto half_step_s_j = solver_graph.register_node(
451 shambase::get_check_ref(half_step_s_j)
452 .set_edges(
453 solver_graph.get_edge_ptr<IDataEdge<Tscal>>("dt_half"),
454 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("ds_j_dt"),
455 solver_graph.get_edge_ptr<Indexes<u32>>("part_counts"),
456 solver_graph.get_edge_ptr<FieldRefs<Tscal>>("s_j"));
457 half_step_sequence.push_back(half_step_s_j);
458 }
459
460 return OperationSequence("half step", std::move(half_step_sequence));
461 };
462
463 solver_graph.register_node("half_step1", make_half_step_sequence("half_step1"));
464 solver_graph.register_node("half_step2", make_half_step_sequence("half_step2"));
465
466 {
467 auto full_step_xyz = solver_graph.register_node(
469 shambase::get_check_ref(full_step_xyz)
470 .set_edges(
471 solver_graph.get_edge_ptr<IDataEdge<Tscal>>("dt"),
472 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("vxyz"),
473 solver_graph.get_edge_ptr<Indexes<u32>>("part_counts"),
474 solver_graph.get_edge_ptr<FieldRefs<Tvec>>("xyz"));
475 }
476
477 {
478 auto leapfrog_predictor = solver_graph.register_node(
479 "leapfrog predictor",
481 "leapfrog predictor",
482 {
483 solver_graph.get_node_ptr_base("half_step1"),
484 solver_graph.get_node_ptr_base("full_step_xyz"),
485 solver_graph.get_node_ptr_base("half_step2"),
486 }));
487 }
488 }
489
491 // Part killing step
493 bool do_part_killing_step = solver_config.particle_killing.kill_list.size() > 0;
494
495 if (do_part_killing_step) {
496
497 auto patchdatas = solver_graph.get_edge_ptr<PatchDataLayerRefs>("scheduler_patchdata");
498 auto xyz_edge = solver_graph.get_edge_ptr<FieldRefs<Tvec>>("xyz");
499
500 auto part_to_remove = solver_graph.register_edge(
501 "part_to_remove", DistributedBuffers<u32>("part_to_remove", "part_to_remove"));
502
503 std::vector<std::shared_ptr<shamrock::solvergraph::INode>> part_kill_sequence{};
504
505 {
506
507 auto empty_part_to_remove
508 = solver_graph.register_node("empty_part_to_remove", NodeFreeAlloc{});
509 shambase::get_check_ref(empty_part_to_remove).set_edges(part_to_remove);
510 part_kill_sequence.push_back(empty_part_to_remove);
511 }
512
513 using kill_t = typename ParticleKillingConfig<Tvec>::kill_t;
514 using kill_sphere = typename ParticleKillingConfig<Tvec>::Sphere;
515
516 // selectors
517 for (kill_t &kill_obj : solver_config.particle_killing.kill_list) {
518 if (kill_sphere *kill_info = std::get_if<kill_sphere>(&kill_obj)) {
519
521 kill_info->center, kill_info->radius);
522 node_selector.set_edges(xyz_edge, part_to_remove);
523
524 part_kill_sequence.push_back(
525 std::make_shared<decltype(node_selector)>(std::move(node_selector)));
526 }
527 }
528
529 { // killing
530 modules::KillParticles node_killer{};
531 node_killer.set_edges(part_to_remove, patchdatas);
532
533 part_kill_sequence.push_back(
534 std::make_shared<decltype(node_killer)>(std::move(node_killer)));
535 }
536
537 solver_graph.register_node(
538 "part killing step",
539 OperationSequence("part killing step", std::move(part_kill_sequence)));
540 }
541
542 {
543 std::vector<std::shared_ptr<shamrock::solvergraph::INode>> seq{};
544
545 seq.push_back(solver_graph.get_node_ptr_base("set_gpart_mass"));
546 seq.push_back(solver_graph.get_node_ptr_base("attach fields to scheduler"));
547 seq.push_back(solver_graph.get_node_ptr_base("leapfrog predictor"));
548 if (do_part_killing_step) {
549 seq.push_back(solver_graph.get_node_ptr_base("part killing step"));
550 }
551
552 storage.solver_sequence = solver_graph.register_node(
553 "time_step", OperationSequence("time step", std::move(seq)));
554 }
555
556 storage.part_counts
557 = std::make_shared<shamrock::solvergraph::Indexes<u32>>("part_counts", "N_{\\rm part}");
558
559 storage.part_counts_with_ghost = std::make_shared<shamrock::solvergraph::Indexes<u32>>(
560 "part_counts_with_ghost", "N_{\\rm part, with ghost}");
561
562 storage.patch_rank_owner = std::make_shared<shamrock::solvergraph::RankGetter>(
563 [&](u64 patch_id) -> u32 {
564 return scheduler().get_patch_rank_owner(patch_id);
565 },
566 "patch_rank_owner",
567 "rank");
568
569 // merged ghost spans
570 storage.positions_with_ghosts
571 = std::make_shared<shamrock::solvergraph::FieldRefs<Tvec>>("part_pos", "\\mathbf{r}");
572 storage.hpart_with_ghosts
573 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("h_part", "h");
574
575 storage.neigh_cache
576 = std::make_shared<shammodels::sph::solvergraph::NeighCache>("neigh_cache", "neigh");
577
578 storage.omega = std::make_shared<shamrock::solvergraph::Field<Tscal>>(1, "omega", "\\Omega");
579
580 if (solver_config.has_field_alphaAV()) {
581 storage.alpha_av_updated = std::make_shared<shamrock::solvergraph::Field<Tscal>>(
582 1, "alpha_av_updated", "\\alpha_{\\rm AV}");
583 }
584
585 storage.pressure = std::make_shared<shamrock::solvergraph::Field<Tscal>>(1, "pressure", "P");
586 storage.soundspeed
587 = std::make_shared<shamrock::solvergraph::Field<Tscal>>(1, "soundspeed", "c_s");
588
589 storage.exchange_gz_alpha
590 = std::make_shared<shamrock::solvergraph::ExchangeGhostField<Tscal>>();
591 storage.exchange_gz_node
592 = std::make_shared<shamrock::solvergraph::ExchangeGhostLayer>(storage.ghost_layout);
593 storage.exchange_gz_positions
594 = std::make_shared<shamrock::solvergraph::ExchangeGhostLayer>(storage.xyzh_ghost_layout);
595}
596
597template<class Tvec, template<class> class Kern>
599 std::string filename, bool add_patch_world_id) {
600
601 modules::VTKDump(context, solver_config).do_dump(filename, add_patch_world_id);
602}
603
605// Debug interface dump
607
608namespace shammodels::sph {
609
610 template<class Tvec>
612 using Tscal = shambase::VecComponent<Tvec>;
613
614 u64 nobj;
615 f64 gpart_mass;
616
617 sycl::buffer<Tvec> &buf_xyz;
618 sycl::buffer<Tscal> &buf_hpart;
619 sycl::buffer<Tvec> &buf_vxyz;
620 };
621
622 template<class Tvec>
623 void fill_blocks(PhantomDumpBlock &block, Debug_ph_dump<Tvec> &info) {
624
625 using Tscal = shambase::VecComponent<Tvec>;
626 std::vector<Tvec> xyz = shamalgs::memory::buf_to_vec(info.buf_xyz, info.nobj);
627
628 u64 xid = block.get_ref_fort_real("x");
629 u64 yid = block.get_ref_fort_real("y");
630 u64 zid = block.get_ref_fort_real("z");
631
632 for (auto vec : xyz) {
633 block.blocks_fort_real[xid].vals.push_back(vec.x());
634 block.blocks_fort_real[yid].vals.push_back(vec.y());
635 block.blocks_fort_real[zid].vals.push_back(vec.z());
636 }
637
638 std::vector<Tscal> h = shamalgs::memory::buf_to_vec(info.buf_hpart, info.nobj);
639 u64 hid = block.get_ref_f32("h");
640 for (auto h_ : h) {
641 block.blocks_f32[hid].vals.push_back(h_);
642 }
643
644 std::vector<Tvec> vxyz = shamalgs::memory::buf_to_vec(info.buf_vxyz, info.nobj);
645
646 u64 vxid = block.get_ref_fort_real("vx");
647 u64 vyid = block.get_ref_fort_real("vy");
648 u64 vzid = block.get_ref_fort_real("vz");
649
650 for (auto vec : vxyz) {
651 block.blocks_fort_real[vxid].vals.push_back(vec.x());
652 block.blocks_fort_real[vyid].vals.push_back(vec.y());
653 block.blocks_fort_real[vzid].vals.push_back(vec.z());
654 }
655
656 block.tot_count = block.blocks_fort_real[xid].vals.size();
657 }
658
659 template<class Tvec>
660 shammodels::sph::PhantomDump make_interface_debug_phantom_dump(Debug_ph_dump<Tvec> info) {
661
662 using Tscal = shambase::VecComponent<Tvec>;
663 PhantomDump dump;
664
666 dump.iversion = 1;
667 dump.fileid = shambase::format("{:100s}", "FT:Phantom Shamrock writer");
668
669 u32 Ntot = info.nobj;
670 dump.table_header_fort_int.add("nparttot", Ntot);
671 dump.table_header_fort_int.add("ntypes", 8);
672 dump.table_header_fort_int.add("npartoftype", Ntot);
673 dump.table_header_fort_int.add("npartoftype", 0);
674 dump.table_header_fort_int.add("npartoftype", 0);
675 dump.table_header_fort_int.add("npartoftype", 0);
676 dump.table_header_fort_int.add("npartoftype", 0);
677 dump.table_header_fort_int.add("npartoftype", 0);
678 dump.table_header_fort_int.add("npartoftype", 0);
679 dump.table_header_fort_int.add("npartoftype", 0);
680
681 dump.table_header_i64.add("nparttot", Ntot);
682 dump.table_header_i64.add("ntypes", 8);
683 dump.table_header_i64.add("npartoftype", Ntot);
684 dump.table_header_i64.add("npartoftype", 0);
685 dump.table_header_i64.add("npartoftype", 0);
686 dump.table_header_i64.add("npartoftype", 0);
687 dump.table_header_i64.add("npartoftype", 0);
688 dump.table_header_i64.add("npartoftype", 0);
689 dump.table_header_i64.add("npartoftype", 0);
690 dump.table_header_i64.add("npartoftype", 0);
691
692 dump.table_header_fort_int.add("nblocks", 1);
693 dump.table_header_fort_int.add("nptmass", 0);
694 dump.table_header_fort_int.add("ndustlarge", 0);
695 dump.table_header_fort_int.add("ndustsmall", 0);
696 dump.table_header_fort_int.add("idust", 7);
697 dump.table_header_fort_int.add("idtmax_n", 1);
698 dump.table_header_fort_int.add("idtmax_frac", 0);
699 dump.table_header_fort_int.add("idumpfile", 0);
700 dump.table_header_fort_int.add("majorv", 2023);
701 dump.table_header_fort_int.add("minorv", 0);
702 dump.table_header_fort_int.add("microv", 0);
703 dump.table_header_fort_int.add("isink", 0);
704
705 dump.table_header_i32.add("iexternalforce", 0);
706 dump.table_header_i32.add("ieos", 2);
707 dump.table_header_fort_real.add("gamma", 1.66667);
708 dump.table_header_fort_real.add("RK2", 0);
709 dump.table_header_fort_real.add("polyk2", 0);
710 dump.table_header_fort_real.add("qfacdisc", 0.75);
711 dump.table_header_fort_real.add("qfacdisc2", 0.75);
712
713 dump.table_header_fort_real.add("time", 0);
714 dump.table_header_fort_real.add("dtmax", 0.1);
715
716 dump.table_header_fort_real.add("rhozero", 0);
717 dump.table_header_fort_real.add("hfact", 1.2);
718 dump.table_header_fort_real.add("tolh", 0.0001);
719 dump.table_header_fort_real.add("C_cour", 0);
720 dump.table_header_fort_real.add("C_force", 0);
721 dump.table_header_fort_real.add("alpha", 0);
722 dump.table_header_fort_real.add("alphau", 1);
723 dump.table_header_fort_real.add("alphaB", 1);
724
725 dump.table_header_fort_real.add("massoftype", info.gpart_mass);
726 dump.table_header_fort_real.add("massoftype", 0);
727 dump.table_header_fort_real.add("massoftype", 0);
728 dump.table_header_fort_real.add("massoftype", 0);
729 dump.table_header_fort_real.add("massoftype", 0);
730 dump.table_header_fort_real.add("massoftype", 0);
731 dump.table_header_fort_real.add("massoftype", 0);
732 dump.table_header_fort_real.add("massoftype", 0);
733
734 dump.table_header_fort_real.add("Bextx", 0);
735 dump.table_header_fort_real.add("Bexty", 0);
736 dump.table_header_fort_real.add("Bextz", 0);
737 dump.table_header_fort_real.add("dum", 0);
738
739 dump.table_header_fort_real.add("get_conserv", -1);
740 dump.table_header_fort_real.add("etot_in", 0.59762);
741 dump.table_header_fort_real.add("angtot_in", 0.0189694);
742 dump.table_header_fort_real.add("totmom_in", 0.0306284);
743
744 dump.table_header_f64.add("udist", 1);
745 dump.table_header_f64.add("umass", 1);
746 dump.table_header_f64.add("utime", 1);
747 dump.table_header_f64.add("umagfd", 3.54491);
748
749 PhantomDumpBlock block_part;
750
751 fill_blocks(block_part, info);
752
753 dump.blocks.push_back(std::move(block_part));
754
755 return dump;
756 }
757
758} // namespace shammodels::sph
759
760template<class Tvec, template<class> class Kern>
761void shammodels::sph::Solver<Tvec, Kern>::gen_serial_patch_tree() {
762 StackEntry stack_loc{};
763
764 SerialPatchTree<Tvec> _sptree = SerialPatchTree<Tvec>::build(scheduler());
765 _sptree.attach_buf();
766 storage.serial_patch_tree.set(std::move(_sptree));
767}
768
774template<class Tvec, template<class> class Kern>
776
777 StackEntry stack_loc{};
778
779 shamlog_debug_ln("SphSolver", "apply position boundary");
780
781 PatchScheduler &sched = scheduler();
782
783 shamrock::SchedulerUtility integrators(sched);
785
786 auto &pdl = sched.pdl_old();
787
788 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
789 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
790 auto [bmin, bmax] = sched.get_box_volume<Tvec>();
791
792 using SolverConfigBC = typename Config::BCConfig;
793 using SolverBCFree = typename SolverConfigBC::Free;
794 using SolverBCPeriodic = typename SolverConfigBC::Periodic;
795 using SolverBCShearingPeriodic = typename SolverConfigBC::ShearingPeriodic;
796 if (SolverBCFree *c = std::get_if<SolverBCFree>(&solver_config.boundary_config.config)) {
797 if (shamcomm::world_rank() == 0) {
798 logger::info_ln("PositionUpdated", "free boundaries skipping geometry update");
799 }
800 } else if (
801 SolverBCPeriodic *c
802 = std::get_if<SolverBCPeriodic>(&solver_config.boundary_config.config)) {
803 integrators.fields_apply_periodicity(ixyz, std::pair{bmin, bmax});
804 } else if (
805 SolverBCShearingPeriodic *c
806 = std::get_if<SolverBCShearingPeriodic>(&solver_config.boundary_config.config)) {
807 integrators.fields_apply_shearing_periodicity(
808 ixyz,
809 ivxyz,
810 std::pair{bmin, bmax},
811 c->shear_base,
812 c->shear_dir,
813 c->shear_speed * time_val,
814 c->shear_speed);
815 }
816
817 reatrib.reatribute_patch_objects(storage.serial_patch_tree.get(), "xyz");
818}
819
820template<class Tvec, template<class> class Kern>
822
823 StackEntry stack_loc{};
824
825 using SPHUtils = sph::SPHUtilities<Tvec, Kernel>;
826 SPHUtils sph_utils(scheduler());
827
828 storage.ghost_patch_cache.set(sph_utils.build_interf_cache(
829 storage.ghost_handler.get(),
830 storage.serial_patch_tree.get(),
831 solver_config.htol_up_coarse_cycle));
832
833 // storage.ghost_handler.get().gen_debug_patch_ghost(storage.ghost_patch_cache.get());
834}
835
836template<class Tvec, template<class> class Kern>
838 StackEntry stack_loc{};
839 storage.ghost_patch_cache.reset();
840}
841
842template<class Tvec, template<class> class Kern>
844
845 StackEntry stack_loc{};
846
847 storage.merged_xyzh.set(storage.ghost_handler.get().build_comm_merge_positions(
848 storage.ghost_patch_cache.get(),
849 storage.exchange_gz_positions,
850 solver_config.show_ghost_zone_graph));
851
852 { // set element counts
853 shambase::get_check_ref(storage.part_counts).indexes
854 = storage.merged_xyzh.get().template map<u32>(
855 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
856 return scheduler().patch_data.get_pdat(id).get_obj_cnt();
857 });
858 }
859
860 { // set element counts
861 shambase::get_check_ref(storage.part_counts_with_ghost).indexes
862 = storage.merged_xyzh.get().template map<u32>(
863 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
864 return mpdat.get_obj_cnt();
865 });
866 }
867
868 { // Attach spans to block coords
869 shambase::get_check_ref(storage.positions_with_ghosts)
870 .set_refs(storage.merged_xyzh.get()
871 .template map<std::reference_wrapper<PatchDataField<Tvec>>>(
872 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
873 return std::ref(mpdat.get_field<Tvec>(0));
874 }));
875
876 shambase::get_check_ref(storage.hpart_with_ghosts)
877 .set_refs(storage.merged_xyzh.get()
878 .template map<std::reference_wrapper<PatchDataField<Tscal>>>(
879 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
880 return std::ref(mpdat.get_field<Tscal>(1));
881 }));
882 }
883}
884
885template<class Tvec, template<class> class Kern>
889
890template<class Tvec, template<class> class Kern>
892 StackEntry stack_loc{};
893 storage.merged_pos_trees.reset();
894}
895
896template<class Tvec, template<class> class Kern>
898 StackEntry stack_loc{};
899
900 using namespace shamrock;
901 using namespace shamrock::patch;
902
904 using SPHUtils = sph::SPHUtilities<Tvec, Kernel>;
905
906 SPHUtils sph_utils(scheduler());
907 shamrock::SchedulerUtility utility(scheduler());
908
909 PatchDataLayerLayout &pdl = scheduler().pdl_old();
910 const u32 ihpart = pdl.get_field_idx<Tscal>("hpart");
911
912 ComputeField<Tscal> _epsilon_h, _h_old;
913
914 auto should_set_omega_mask = std::make_shared<shamrock::solvergraph::Field<u32>>(
915 1, "should_set_omega_mask", "should_set_omega_mask");
916
917 u32 hstep_cnt = 0;
918 u32 hstep_max = solver_config.h_max_subcycles_count;
919 for (; hstep_cnt < hstep_max; hstep_cnt++) {
920
921 gen_ghost_handler(time_val + dt);
927
928 _epsilon_h = utility.make_compute_field<Tscal>("epsilon_h", 1, Tscal(100));
929 _h_old = utility.save_field<Tscal>(ihpart, "h_old");
930
931 Tscal max_eps_h;
932
933 if (solver_config.gpart_mass == 0) {
935 "invalid gpart_mass {}, this configuration can not converge.\n"
936 "Please set it using either model.set_particle_mass(pmass) or "
937 "cfg.set_particle_mass(pmass)",
938 solver_config.gpart_mass));
939 }
940
941 // sizes
942 std::shared_ptr<shamrock::solvergraph::Indexes<u32>> sizes
943 = std::make_shared<shamrock::solvergraph::Indexes<u32>>("", "");
944 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
945 sizes->indexes.add_obj(p.id_patch, pdat.get_obj_cnt());
946 });
947
948 // neigh cache
949 auto &neigh_cache = storage.neigh_cache;
950
951 // positions
952 auto &pos_merged = storage.positions_with_ghosts;
953
954 // old smoothing length field
955 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> hold
956 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("", "");
958 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
959 auto &field = _h_old.get_field(p.id_patch);
960 hold_refs.add_obj(p.id_patch, std::ref(field));
961 });
962 hold->set_refs(hold_refs);
963
964 // new smoothing length field
965 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> hnew
966 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("", "");
968 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
969 auto &field = pdat.get_field<Tscal>(ihpart);
970 hnew_refs.add_obj(p.id_patch, std::ref(field));
971 });
972 hnew->set_refs(hnew_refs);
973
974 // epsilon field
975 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> eps_h
976 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("", "");
978 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
979 auto &field = _epsilon_h.get_field(p.id_patch);
980 eps_h_refs.add_obj(p.id_patch, std::ref(field));
981 });
982 eps_h->set_refs(eps_h_refs);
983
984 std::shared_ptr<shamrock::solvergraph::INode> smth_h_iter_ptr;
985
986 using h_conf_density_based = typename SmoothingLengthConfig::DensityBased;
987 using h_conf_neigh_lim = typename SmoothingLengthConfig::DensityBasedNeighLim;
988
989 if (h_conf_density_based *conf
990 = std::get_if<h_conf_density_based>(&solver_config.smoothing_length_config.config)) {
991 std::shared_ptr<shammodels::sph::modules::IterateSmoothingLengthDensity<Tvec, Kernel>>
992 smth_h_iter = std::make_shared<
994 solver_config.gpart_mass,
995 solver_config.htol_up_coarse_cycle,
996 solver_config.htol_up_fine_cycle);
997 smth_h_iter->set_edges(sizes, neigh_cache, pos_merged, hold, hnew, eps_h);
998 smth_h_iter_ptr = smth_h_iter;
999 } else if (
1000 h_conf_neigh_lim *conf
1001 = std::get_if<h_conf_neigh_lim>(&solver_config.smoothing_length_config.config)) {
1002 std::shared_ptr<
1004 smth_h_iter_neigh_lim = std::make_shared<
1006 solver_config.gpart_mass,
1007 solver_config.htol_up_coarse_cycle,
1008 solver_config.htol_up_fine_cycle,
1009 conf->max_neigh_count);
1010 smth_h_iter_neigh_lim->set_edges(
1011 sizes, neigh_cache, pos_merged, hold, hnew, eps_h, should_set_omega_mask);
1012 smth_h_iter_ptr = smth_h_iter_neigh_lim;
1013 } else {
1014 shambase::throw_with_loc<std::runtime_error>("Invalid smoothing length configuration");
1015 }
1016 // iterate smoothing length
1017
1018 std::shared_ptr<shamrock::solvergraph::ScalarEdge<bool>> is_converged
1019 = std::make_shared<shamrock::solvergraph::ScalarEdge<bool>>("", "");
1020
1022 smth_h_iter_ptr, solver_config.epsilon_h, solver_config.h_iter_per_subcycles, false);
1023 loop_smth_h_iter.set_edges(eps_h, is_converged);
1024
1025 loop_smth_h_iter.evaluate();
1026
1027 if (!is_converged->value) {
1028
1029 Tscal largest_h = 0;
1030
1031 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
1032 largest_h = sham::max(largest_h, pdat.get_field<Tscal>(ihpart).compute_max());
1033 });
1034 Tscal global_largest_h = shamalgs::collective::allreduce_max(largest_h);
1035
1036 std::string add_info = "";
1037 u64 cnt_unconverged = 0;
1038 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
1039 auto res
1040 = _epsilon_h.get_field(p.id_patch).get_ids_buf_where([](auto access, u32 id) {
1041 return access[id] == -1;
1042 });
1043
1044 if (hstep_cnt == hstep_max - 1) {
1045 if (std::get<0>(res)) {
1046 add_info += "\n patch " + std::to_string(p.id_patch) + " ";
1047 add_info += "errored parts : \n";
1048 sycl::buffer<u32> &idx_err = *std::get<0>(res);
1049
1050 sham::DeviceBuffer<Tvec> &xyz = pdat.get_field_buf_ref<Tvec>(0);
1051 sham::DeviceBuffer<Tscal> &hpart = pdat.get_field_buf_ref<Tscal>(ihpart);
1052
1053 auto pos = xyz.copy_to_stdvec();
1054 auto h = hpart.copy_to_stdvec();
1055
1056 {
1057 sycl::host_accessor acc{idx_err};
1058 for (u32 i = 0; i < idx_err.size(); i++) {
1059 add_info += shambase::format(
1060 "{} - pos : {}, hpart : {}\n", acc[i], pos[acc[i]], h[acc[i]]);
1061 }
1062 }
1063 }
1064 }
1065
1066 cnt_unconverged += std::get<1>(res);
1067 });
1068
1069 u64 global_cnt_unconverged = shamalgs::collective::allreduce_sum(cnt_unconverged);
1070
1071 if (shamcomm::world_rank() == 0) {
1073 "Smoothinglength",
1074 "smoothing length is not converged, rerunning the iterator ...\n largest h "
1075 "=",
1076 global_largest_h,
1077 "unconverged cnt =",
1078 global_cnt_unconverged,
1079 add_info);
1080 }
1081
1082 reset_ghost_handler();
1084
1085 shambase::get_check_ref(storage.part_counts).free_alloc();
1086 shambase::get_check_ref(storage.part_counts_with_ghost).free_alloc();
1087 shambase::get_check_ref(storage.positions_with_ghosts).free_alloc();
1088 shambase::get_check_ref(storage.hpart_with_ghosts).free_alloc();
1089
1090 storage.merged_xyzh.reset();
1091
1095
1096 // scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchData &pdat) {
1097 // pdat.synchronize_buf();
1098 // });
1099
1100 continue;
1101 }
1102
1103 // The hpart is not valid anymore in ghost zones since we iterated its value
1104 shambase::get_check_ref(storage.hpart_with_ghosts).free_alloc();
1105
1106 _epsilon_h.reset();
1107 _h_old.reset();
1108 break;
1109 }
1110
1111 if (hstep_cnt == hstep_max) {
1112 logger::err_ln("SPH", "the h iterator is not converged after", hstep_cnt, "iterations");
1113 }
1114
1115 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> hnew_edge
1116 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("", "");
1118 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
1119 auto &field = pdat.get_field<Tscal>(ihpart);
1120 hnew_refs.add_obj(p.id_patch, std::ref(field));
1121 });
1122 hnew_edge->set_refs(hnew_refs);
1123
1124 modules::NodeComputeOmega<Tvec, Kern> compute_omega{solver_config.gpart_mass};
1125 compute_omega.set_edges(
1126 storage.part_counts,
1127 storage.neigh_cache,
1128 storage.positions_with_ghosts,
1129 hnew_edge,
1130 storage.omega);
1131 compute_omega.evaluate();
1132
1133 if (solver_config.smoothing_length_config.is_density_based_neigh_lim()) {
1134 // if the h limiter is triggered, omega does not hold its meaning of dh/dr anymore
1135 // so we set it to 1, which is effectively equivalent to disabling the energy correction
1136 // term corresponding to dh/dr
1137 modules::SetWhenMask<Tscal> set_omega_mask{1};
1138 set_omega_mask.set_edges(storage.part_counts, should_set_omega_mask, storage.omega);
1139 set_omega_mask.evaluate();
1140 }
1141}
1142
// Creates the two ghost-zone layouts stored in `storage`:
//  - `storage.ghost_layout`: allocated empty here, then handed to
//    `solver_config.set_ghost_layout` (presumably to register its fields — confirm there).
//  - `storage.xyzh_ghost_layout`: receives exactly two fields, "xyz" (Tvec) and "hpart" (Tscal).
// NOTE(review): listing lines 1144 (function signature) and 1148 (left-hand side of the
// declaration completed on line 1149) are absent from this rendering.
1143template<class Tvec, template<class> class Kern>
1145
1146 storage.ghost_layout = std::make_shared<shamrock::patch::PatchDataLayerLayout>();
1147
    // Reference to the freshly allocated layout; get_check_ref is applied to the shared pointer.
1149 = shambase::get_check_ref(storage.ghost_layout);
1150
1151 solver_config.set_ghost_layout(ghost_layout);
1152
    // Reduced layout holding only positions and smoothing lengths.
    // The second argument (1) is presumably the number of values per object — TODO confirm.
1153 storage.xyzh_ghost_layout = std::make_shared<shamrock::patch::PatchDataLayerLayout>();
1154 storage.xyzh_ghost_layout->template add_field<Tvec>("xyz", 1);
1155 storage.xyzh_ghost_layout->template add_field<Tscal>("hpart", 1);
1156}
1157
// Fills `storage.rtree_rint_field` with one tree field per entry of `storage.merged_pos_trees`.
// For each tree, the field is computed by `compute_tree_field_max_field` from buffer 1 of the
// matching merged xyzh patch data, then every entry is multiplied by
// `solver_config.htol_up_coarse_cycle`.
// NOTE(review): the function signature (listing line 1159) is absent from this rendering.
1158template<class Tvec, template<class> class Kern>
1160
1161 StackEntry stack_loc{};
1162
1163 auto &xyzh_merged = storage.merged_xyzh.get();
1164 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
1165
1166 storage.rtree_rint_field.set(
1167 storage.merged_pos_trees.get().template map<shamtree::KarrasRadixTreeField<Tscal>>(
1168 [&](u64 id, RTree &rtree) -> shamtree::KarrasRadixTreeField<Tscal> {
1169 shamrock::patch::PatchDataLayer &tmp = xyzh_merged.get(id);
    // Field index 1 — presumably "hpart", assuming the merged data uses the xyz/hpart
    // layout (index 0 = xyz, index 1 = hpart) — TODO confirm.
1170 auto &buf = tmp.get_field_buf_ref<Tscal>(1);
1171 auto buf_int = shamtree::new_empty_karras_radix_tree_field<Tscal>();
1172
1173 auto ret = shamtree::compute_tree_field_max_field<Tscal>(
1174 rtree.structure,
1175 rtree.reduced_morton_set.get_leaf_cell_iterator(),
1176 std::move(buf_int),
1177 buf);
1178
1179 // the old tree used to increase the size of the hmax of the tree nodes by the
1180 // tolerance so we do it also with the new tree, maybe we should move that somewhere
1181 // else.
    // The kernel takes no read-only inputs (empty MultiRef) and scales the tree field in place.
1182 sham::kernel_call(
1183 dev_sched->get_queue(),
1184 sham::MultiRef{},
1185 sham::MultiRef{ret.buf_field},
1186 ret.buf_field.get_size(),
1187 [htol = solver_config.htol_up_coarse_cycle](u32 i, Tscal *h_tree) {
1188 h_tree[i] *= htol;
1189 });
1190
1191 return std::move(ret);
1192 }));
1193}
1194
// Resets `storage.rtree_rint_field`, the value set by the function just above.
// NOTE(review): the function signature (listing line 1196) is absent from this rendering.
1195template<class Tvec, template<class> class Kern>
1197 storage.rtree_rint_field.reset();
1198}
1199
// Builds the neighbor cache, then optionally evaluates neighbor statistics.
// The build variant depends on `solver_config.use_two_stage_search`:
// `start_neighbors_cache_2stages()` when set, `start_neighbors_cache()` otherwise.
// NOTE(review): the function signature (listing line 1201) and the object on which both build
// calls are made (listing lines 1203 and 1207) are absent from this rendering.
1200template<class Tvec, template<class> class Kern>
1202 if (solver_config.use_two_stage_search) {
1204 context, solver_config, storage)
1205 .start_neighbors_cache_2stages();
1206 } else {
1208 context, solver_config, storage)
1209 .start_neighbors_cache();
1210 }
1211
    // Optional diagnostics: a ComputeNeighStats node, constructed with Kernel::Rkern, is
    // evaluated on the particle counts, neighbor cache, merged positions and merged hpart.
1212 if (solver_config.show_neigh_stats) {
1213 auto &pos_merged = storage.positions_with_ghosts;
1214 auto &neigh_cache = storage.neigh_cache;
1215 auto &hpart_with_ghosts = storage.hpart_with_ghosts;
1216 auto &part_counts = storage.part_counts;
1217
1218 modules::ComputeNeighStats<Tvec> compute_neigh_stats(Kernel::Rkern);
1219
1220 compute_neigh_stats.set_edges(part_counts, neigh_cache, pos_merged, hpart_with_ghosts);
1221 compute_neigh_stats.evaluate();
1222 }
1223}
1224
// The visible body has no effect: its only statement is commented out.
// NOTE(review): the function signature (listing line 1226) is absent from this rendering.
1225template<class Tvec, template<class> class Kern>
1227 // storage.neighbors_cache.reset();
1228}
1229
// Builds the ghost-zone (interface) patch data for the fields of `storage.ghost_layout`,
// communicates it, merges it with each patch's local data, and stores the result in
// `storage.merged_patchdata_ghost`.
//
// Visible steps:
//  1. look up field indices in the scheduler layout and in the ghost layout;
//  2. create one empty interface PatchDataLayer per entry of the ghost patch cache;
//  3. append the selected subset of each sender's fields (and of omega) to it;
//  4. apply the interface velocity offset when it is non-zero;
//  5. exchange the interface data through `ghost_handle.communicate_pdat`;
//  6. merge local data and received interfaces into a new PatchDataLayer per patch.
// The elapsed time is added to `storage.timings_details.interface`.
//
// NOTE(review): the function signature (listing line 1231), the declaration of `omega`
// (listing line 1307) and the header of the first merge lambda (listing line 1410) are
// absent from this rendering.
1230template<class Tvec, template<class> class Kern>
1232
1233 StackEntry stack_loc{};
1234
1235 shambase::Timer timer_interf;
1236 timer_interf.start();
1237
1238 using namespace shamrock;
1239 using namespace shamrock::patch;
1240
    // Flags telling which optional fields exist for the current solver configuration.
1241 bool has_alphaAV_field = solver_config.has_field_alphaAV();
1242 bool has_soundspeed_field = solver_config.ghost_has_soundspeed();
1243
1244 bool has_B_field = solver_config.has_field_B_on_rho();
1245 bool has_psi_field = solver_config.has_field_psi_on_ch();
1246 bool has_curlB_field = solver_config.has_field_curlB();
1247 bool has_epsilon_field = solver_config.dust_config.has_epsilon_field();
1248 bool has_deltav_field = solver_config.dust_config.has_deltav_field();
1249 bool has_s_j_field = solver_config.dust_config.has_s_j_field();
1250
    // Field indices in the scheduler layout. Optional fields fall back to index 0 when the
    // corresponding flag is false; those indices are only used below under the same flag.
1251 PatchDataLayerLayout &pdl = scheduler().pdl_old();
1252 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
1253 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
1254 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
1255 const u32 iuint = pdl.get_field_idx<Tscal>("uint");
1256 const u32 iduint = pdl.get_field_idx<Tscal>("duint");
1257 const u32 ihpart = pdl.get_field_idx<Tscal>("hpart");
1258
1259 const u32 ialpha_AV = (has_alphaAV_field) ? pdl.get_field_idx<Tscal>("alpha_AV") : 0;
1260 const u32 isoundspeed = (has_soundspeed_field) ? pdl.get_field_idx<Tscal>("soundspeed") : 0;
1261
1262 const u32 iB_on_rho = (has_B_field) ? pdl.get_field_idx<Tvec>("B/rho") : 0;
1263 const u32 idB_on_rho = (has_B_field) ? pdl.get_field_idx<Tvec>("dB/rho") : 0;
1264 const u32 ipsi_on_ch = (has_psi_field) ? pdl.get_field_idx<Tscal>("psi/ch") : 0;
1265 const u32 idpsi_on_ch = (has_psi_field) ? pdl.get_field_idx<Tscal>("dpsi/ch") : 0;
1266 const u32 icurlB = (has_curlB_field) ? pdl.get_field_idx<Tvec>("curlB") : 0;
1267
    // MHD debug field indices. When MHD debug is off the fallback is -1, which wraps to the
    // largest u32 value. None of these indices is referenced again in the lines visible here.
1268 bool do_MHD_debug = solver_config.do_MHD_debug();
1269 const u32 imag_pressure = (do_MHD_debug) ? pdl.get_field_idx<Tvec>("mag_pressure") : -1;
1270 const u32 imag_tension = (do_MHD_debug) ? pdl.get_field_idx<Tvec>("mag_tension") : -1;
1271 const u32 igas_pressure = (do_MHD_debug) ? pdl.get_field_idx<Tvec>("gas_pressure") : -1;
1272 const u32 itensile_corr = (do_MHD_debug) ? pdl.get_field_idx<Tvec>("tensile_corr") : -1;
1273 const u32 ipsi_propag = (do_MHD_debug) ? pdl.get_field_idx<Tscal>("psi_propag") : -1;
1274 const u32 ipsi_diff = (do_MHD_debug) ? pdl.get_field_idx<Tscal>("psi_diff") : -1;
1275 const u32 ipsi_cons = (do_MHD_debug) ? pdl.get_field_idx<Tscal>("psi_cons") : -1;
1276 const u32 iu_mhd = (do_MHD_debug) ? pdl.get_field_idx<Tscal>("u_mhd") : -1;
1277
1278 const u32 iepsilon = (has_epsilon_field) ? pdl.get_field_idx<Tscal>("epsilon") : 0;
1279 const u32 ideltav = (has_deltav_field) ? pdl.get_field_idx<Tvec>("deltav") : 0;
1280 const u32 is_j = (has_s_j_field) ? pdl.get_field_idx<Tscal>("s_j") : 0;
1281
    // Field indices in the ghost layout (destination side of every copy below).
1282 auto &ghost_layout_ptr = storage.ghost_layout;
1283 shamrock::patch::PatchDataLayerLayout &ghost_layout = shambase::get_check_ref(ghost_layout_ptr);
1284 u32 ihpart_interf = ghost_layout.get_field_idx<Tscal>("hpart");
1285 u32 iuint_interf = ghost_layout.get_field_idx<Tscal>("uint");
1286 u32 ivxyz_interf = ghost_layout.get_field_idx<Tvec>("vxyz");
1287 u32 iomega_interf = ghost_layout.get_field_idx<Tscal>("omega");
1288
1289 const u32 iaxyz_interf
1290 = (solver_config.has_axyz_in_ghost()) ? ghost_layout.get_field_idx<Tvec>("axyz") : 0;
1291
1292 const u32 isoundspeed_interf
1293 = (has_soundspeed_field) ? ghost_layout.get_field_idx<Tscal>("soundspeed") : 0;
1294
1295 const u32 iB_interf = (has_B_field) ? ghost_layout.get_field_idx<Tvec>("B/rho") : 0;
1296 const u32 ipsi_interf = (has_psi_field) ? ghost_layout.get_field_idx<Tscal>("psi/ch") : 0;
1297 const u32 icurlB_interf = (has_curlB_field) ? ghost_layout.get_field_idx<Tvec>("curlB") : 0;
1298
1299 const u32 iepsilon_interf
1300 = (has_epsilon_field) ? ghost_layout.get_field_idx<Tscal>("epsilon") : 0;
1301 const u32 ideltav_interf = (has_deltav_field) ? ghost_layout.get_field_idx<Tvec>("deltav") : 0;
1302 const u32 is_j_interf = (has_s_j_field) ? ghost_layout.get_field_idx<Tscal>("s_j") : 0;
1303
1304 using InterfaceBuildInfos = typename sph::BasicSPHGhostHandler<Tvec>::InterfaceBuildInfos;
1305
1306 sph::BasicSPHGhostHandler<Tvec> &ghost_handle = storage.ghost_handler.get();
1308
    // Step 2: one empty PatchDataLayer (ghost layout) per interface, with `cnt` objects reserved.
1309 auto pdat_interf = ghost_handle.template build_interface_native<PatchDataLayer>(
1310 storage.ghost_patch_cache.get(),
1311 [&](u64 sender, u64, InterfaceBuildInfos binfo, sham::DeviceBuffer<u32> &buf_idx, u32 cnt) {
1312 PatchDataLayer pdat(ghost_layout_ptr);
1313
1314 pdat.reserve(cnt);
1315
1316 return pdat;
1317 });
1318
    // Step 3: for each interface, append the subset (buf_idx, cnt) of the sender's fields to
    // the matching ghost-layout field. Optional fields are copied only when their flag is set.
1319 ghost_handle.template modify_interface_native<PatchDataLayer>(
1320 storage.ghost_patch_cache.get(),
1321 pdat_interf,
1322 [&](u64 sender,
1323 u64,
1324 InterfaceBuildInfos binfo,
1325 sham::DeviceBuffer<u32> &buf_idx,
1326 u32 cnt,
1327 PatchDataLayer &pdat) {
1328 PatchDataLayer &sender_patch = scheduler().patch_data.get_pdat(sender);
    // `omega` is not declared in the lines visible here (see the note at the top).
1329 PatchDataField<Tscal> &sender_omega = omega.get(sender);
1330
1331 sender_patch.get_field<Tscal>(ihpart).append_subset_to(
1332 buf_idx, cnt, pdat.get_field<Tscal>(ihpart_interf));
1333 sender_patch.get_field<Tscal>(iuint).append_subset_to(
1334 buf_idx, cnt, pdat.get_field<Tscal>(iuint_interf));
1335
1336 if (solver_config.has_axyz_in_ghost()) {
1337 sender_patch.get_field<Tvec>(iaxyz).append_subset_to(
1338 buf_idx, cnt, pdat.get_field<Tvec>(iaxyz_interf));
1339 }
1340
1341 sender_patch.get_field<Tvec>(ivxyz).append_subset_to(
1342 buf_idx, cnt, pdat.get_field<Tvec>(ivxyz_interf));
1343
    // omega is taken from the separate `omega` container rather than from the sender patch data.
1344 sender_omega.append_subset_to(buf_idx, cnt, pdat.get_field<Tscal>(iomega_interf));
1345
1346 if (has_soundspeed_field) {
1347 sender_patch.get_field<Tscal>(isoundspeed)
1348 .append_subset_to(buf_idx, cnt, pdat.get_field<Tscal>(isoundspeed_interf));
1349 }
1350
1351 if (has_B_field) {
1352 sender_patch.get_field<Tvec>(iB_on_rho).append_subset_to(
1353 buf_idx, cnt, pdat.get_field<Tvec>(iB_interf));
1354 }
1355
1356 if (has_psi_field) {
1357 sender_patch.get_field<Tscal>(ipsi_on_ch)
1358 .append_subset_to(buf_idx, cnt, pdat.get_field<Tscal>(ipsi_interf));
1359 }
1360
1361 if (has_curlB_field) {
1362 sender_patch.get_field<Tvec>(icurlB).append_subset_to(
1363 buf_idx, cnt, pdat.get_field<Tvec>(icurlB_interf));
1364 }
1365
1366 if (has_epsilon_field) {
1367 sender_patch.get_field<Tscal>(iepsilon).append_subset_to(
1368 buf_idx, cnt, pdat.get_field<Tscal>(iepsilon_interf));
1369 }
1370
1371 if (has_deltav_field) {
1372 sender_patch.get_field<Tvec>(ideltav).append_subset_to(
1373 buf_idx, cnt, pdat.get_field<Tvec>(ideltav_interf));
1374 }
1375
1376 if (has_s_j_field) {
1377 sender_patch.get_field<Tscal>(is_j).append_subset_to(
1378 buf_idx, cnt, pdat.get_field<Tscal>(is_j_interf));
1379 }
1380 });
1381
    // Step 4: add the interface velocity offset to the ghost velocities, skipped when the
    // offset vector has zero length.
1382 ghost_handle.template modify_interface_native<PatchDataLayer>(
1383 storage.ghost_patch_cache.get(),
1384 pdat_interf,
1385 [&](u64 sender,
1386 u64,
1387 InterfaceBuildInfos binfo,
1388 sham::DeviceBuffer<u32> &buf_idx,
1389 u32 cnt,
1390 PatchDataLayer &pdat) {
1391 if (sycl::length(binfo.offset_speed) > 0) {
1392 pdat.get_field<Tvec>(ivxyz_interf).apply_offset(binfo.offset_speed);
1393 }
1394 });
1395
    // Step 5: exchange the interface data; `pdat_interf` is moved from and not used afterwards.
1396 shambase::DistributedDataShared<PatchDataLayer> interf_pdat = ghost_handle.communicate_pdat(
1397 ghost_layout_ptr,
1398 std::move(pdat_interf),
1399 storage.exchange_gz_node,
1400 solver_config.show_ghost_zone_graph);
1401
    // Total number of received ghost objects per receiver id `r`; used to size the reserve below.
1402 std::map<u64, u64> sz_interf_map;
1403 interf_pdat.for_each([&](u64 s, u64 r, PatchDataLayer &pdat_interf) {
1404 sz_interf_map[r] += pdat_interf.get_obj_cnt();
1405 });
1406
    // Step 6: merge. The first lambda (its header, listing line 1410, is not visible) builds
    // a new ghost-layout PatchDataLayer from the local data `pdat` of patch `p`; the second
    // lambda appends each received interface to it.
1407 storage.merged_patchdata_ghost.set(
1408 ghost_handle.template merge_native<PatchDataLayer, PatchDataLayer>(
1409 std::move(interf_pdat),
1411 PatchDataLayer pdat_new(ghost_layout_ptr);
1412
1413 u32 or_elem = pdat.get_obj_cnt();
1414 pdat_new.reserve(or_elem + sz_interf_map[p.id_patch]);
    // `total_elements` is not read again in the lines visible here.
1415 u32 total_elements = or_elem;
1416
1417 PatchDataField<Tscal> &cur_omega = omega.get(p.id_patch);
1418
1419 pdat_new.get_field<Tscal>(ihpart_interf).insert(pdat.get_field<Tscal>(ihpart));
1420 pdat_new.get_field<Tscal>(iuint_interf).insert(pdat.get_field<Tscal>(iuint));
1421 pdat_new.get_field<Tvec>(ivxyz_interf).insert(pdat.get_field<Tvec>(ivxyz));
1422
1423 if (solver_config.has_axyz_in_ghost()) {
1424 pdat_new.get_field<Tvec>(iaxyz_interf).insert(pdat.get_field<Tvec>(iaxyz));
1425 }
1426
1427 pdat_new.get_field<Tscal>(iomega_interf).insert(cur_omega);
1428
1429 if (has_soundspeed_field) {
1430 pdat_new.get_field<Tscal>(isoundspeed_interf)
1431 .insert(pdat.get_field<Tscal>(isoundspeed));
1432 }
1433
1434 if (has_B_field) {
1435 pdat_new.get_field<Tvec>(iB_interf).insert(pdat.get_field<Tvec>(iB_on_rho));
1436 }
1437
1438 if (has_psi_field) {
1439 pdat_new.get_field<Tscal>(ipsi_interf)
1440 .insert(pdat.get_field<Tscal>(ipsi_on_ch));
1441 }
1442
1443 if (has_curlB_field) {
1444 pdat_new.get_field<Tvec>(icurlB_interf).insert(pdat.get_field<Tvec>(icurlB));
1445 }
1446
1447 if (has_epsilon_field) {
1448 pdat_new.get_field<Tscal>(iepsilon_interf)
1449 .insert(pdat.get_field<Tscal>(iepsilon));
1450 }
1451
1452 if (has_deltav_field) {
1453 pdat_new.get_field<Tvec>(ideltav_interf).insert(pdat.get_field<Tvec>(ideltav));
1454 }
1455
1456 if (has_s_j_field) {
1457 pdat_new.get_field<Tscal>(is_j_interf).insert(pdat.get_field<Tscal>(is_j));
1458 }
1459
    // Consistency check on the per-field object counts after the local inserts.
1460 pdat_new.check_field_obj_cnt_match();
1461
1462 return pdat_new;
1463 },
1464 [](PatchDataLayer &pdat, PatchDataLayer &pdat_interf) {
1465 pdat.insert_elements(pdat_interf);
1466 }));
1467
1468 timer_interf.stop();
1469 storage.timings_details.interface += timer_interf.elapsed_sec();
1470}
1471
// Resets `storage.merged_patchdata_ghost`, the merged ghost patch data set by the function above.
// NOTE(review): the function signature (listing line 1473) is absent from this rendering.
1472template<class Tvec, template<class> class Kern>
1474 storage.merged_patchdata_ghost.reset();
1475}
1476
1478// start artificial viscosity section //////////////////////////////////////////////////////////////
1480
// Delegates to the `sph::modules::UpdateViscosity` module: a temporary module object is
// built from (context, solver_config, storage) and `update_artificial_viscosity(dt)` is called.
// NOTE(review): the function signature (listing line 1482) is absent from this rendering;
// `dt` is presumably its parameter.
1481template<class Tvec, template<class> class Kern>
1483
1484 sph::modules::UpdateViscosity<Tvec, Kern>(context, solver_config, storage)
1485 .update_artificial_viscosity(dt);
1486}
1487
1489// end artificial viscosity section ////////////////////////////////////////////////////////////////
1491
1492template<class Tvec, template<class> class Kern>
1497
// Calls `free_alloc()` on `storage.pressure` and `storage.soundspeed`, each accessed
// through `shambase::get_check_ref`.
// NOTE(review): the function signature (listing line 1499) is absent from this rendering.
1498template<class Tvec, template<class> class Kern>
1500 shambase::get_check_ref(storage.pressure).free_alloc();
1501 shambase::get_check_ref(storage.soundspeed).free_alloc();
1502}
1503
// Saves copies of the current time-derivative fields into `storage.old_*`.
// "axyz" and "duint" are always saved; "dB/rho", "dpsi/ch", "dtepsilon", "dtdeltav" and
// "ds_j_dt" are saved only when the matching feature flag is set in the solver config.
// NOTE(review): the function signature (listing line 1505) is absent from this rendering.
1504template<class Tvec, template<class> class Kern>
1506
1507 StackEntry stack_loc{};
1508
1509 using namespace shamrock;
1510 using namespace shamrock::patch;
1511 shamrock::SchedulerUtility utility(scheduler());
1512 PatchDataLayerLayout &pdl = scheduler().pdl_old();
1513
1514 bool has_B_field = solver_config.has_field_B_on_rho();
1515 bool has_psi_field = solver_config.has_field_psi_on_ch();
1516 bool has_epsilon_field = solver_config.dust_config.has_epsilon_field();
1517 bool has_deltav_field = solver_config.dust_config.has_deltav_field();
1518 bool has_s_j_field = solver_config.dust_config.has_s_j_field();
1519
    // Optional indices fall back to 0 when the field is absent; they are only used under
    // the same flag below.
1520 const u32 iduint = pdl.get_field_idx<Tscal>("duint");
1521 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
1522 const u32 idB_on_rho = (has_B_field) ? pdl.get_field_idx<Tvec>("dB/rho") : 0;
1523 const u32 idpsi_on_ch = (has_psi_field) ? pdl.get_field_idx<Tscal>("dpsi/ch") : 0;
1524
1525 shamlog_debug_ln("sph::BasicGas", "save old fields");
1526 storage.old_axyz.set(utility.save_field<Tvec>(iaxyz, "axyz_old"));
1527 storage.old_duint.set(utility.save_field<Tscal>(iduint, "duint_old"));
1528
1529 if (has_B_field) {
1530 storage.old_dB_on_rho.set(utility.save_field<Tvec>(idB_on_rho, "dB/rho_old"));
1531 }
1532 if (has_psi_field) {
1533 storage.old_dpsi_on_ch.set(utility.save_field<Tscal>(idpsi_on_ch, "dpsi/ch_old"));
1534 }
    // For the dust-related fields the index is looked up inline rather than cached above.
1535 if (has_epsilon_field) {
1536 storage.old_dtepsilon.set(
1537 utility.save_field<Tscal>(pdl.get_field_idx<Tscal>("dtepsilon"), "dtepsilon_old"));
1538 }
1539 if (has_deltav_field) {
1540 storage.old_dtdeltav.set(
1541 utility.save_field<Tvec>(pdl.get_field_idx<Tvec>("dtdeltav"), "dtdeltav_old"));
1542 }
1543 if (has_s_j_field) {
1544 storage.old_ds_j_dt.set(
1545 utility.save_field<Tscal>(pdl.get_field_idx<Tscal>("ds_j_dt"), "ds_j_dt_old"));
1546 }
1547}
1548
// Runs two modules in sequence: `UpdateDerivs::update_derivs(dt_hydro)`, then
// `ExternalForces::add_ext_forces()`. Both are built from (context, solver_config, storage).
// NOTE(review): the function signature (listing line 1550) is absent from this rendering;
// `dt_hydro` is presumably its parameter.
1549template<class Tvec, template<class> class Kern>
1551
1552 modules::UpdateDerivs<Tvec, Kern> derivs(context, solver_config, storage);
1553 derivs.update_derivs(dt_hydro);
1554
1555 modules::ExternalForces<Tvec, Kern> ext_forces(context, solver_config, storage);
1556 ext_forces.add_ext_forces();
1557}
1558
// The visible body unconditionally returns false.
// NOTE(review): the function signature (listing line 1560) is absent from this rendering,
// so the meaning of the return value cannot be determined from here.
1559template<class Tvec, template<class> class Kern>
1561 return false;
1562}
1563
// Calls `update_load_balancing()` on a `ComputeLoadBalanceValue` module, then
// `scheduler().scheduler_step(false, false)`.
// NOTE(review): the function signature (listing line 1565) is absent from this rendering;
// the meaning of the two boolean arguments is defined by the scheduler, not shown here.
1564template<class Tvec, template<class> class Kern>
1566 modules::ComputeLoadBalanceValue<Tvec, Kern>(context, solver_config, storage)
1567 .update_load_balancing();
1568 scheduler().scheduler_step(false, false);
1569}
1570
// Points `refs` at field `field_idx` of every non-empty patch known to `sched`.
//
// sched     : scheduler whose non-empty patch data are visited.
// field_idx : index of the field (of type T) to reference in each patch's data.
// refs      : edge receiving the collected references through `set_refs`.
//
// The stored references point into the scheduler's patch data.
// NOTE(review): the declaration of `field_refs` (listing line 1578) is absent from this rendering.
1571template<class T>
1572void map_field_refs(
1573 PatchScheduler &sched, u32 field_idx, shamrock::solvergraph::FieldRefs<T> &refs) {
1574
1575 using namespace shamrock::solvergraph;
1576 using namespace shamrock::patch;
1577
    // One reference per non-empty patch, keyed by patch id.
1579 sched.for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
1580 auto &field = pdat.get_field<T>(field_idx);
1581 field_refs.add_obj(p.id_patch, std::ref(field));
1582 });
1583 refs.set_refs(field_refs);
1584}
1585
// Variant of the field-reference mapping that reads from an external per-patch container.
// For every non-empty patch of `sched`, the field `field_idx` is taken from
// `mpdats.get(p.id_patch)` instead of the scheduler's own patch data, and the collected
// references are stored with `refs.set_refs`.
// NOTE(review): the parameter lines declaring `mpdats` and `refs` (listing lines 1589 and
// 1591) and the declaration of `field_refs` (listing line 1596) are absent from this rendering.
1586template<class T>
1587void map_field_refs_ext(
1588 PatchScheduler &sched,
1590 u32 field_idx,
1592
1593 using namespace shamrock::solvergraph;
1594 using namespace shamrock::patch;
1595
    // `pdat` (the scheduler's own data) is unused here; only the patch id selects the entry
    // of `mpdats`.
1597 sched.for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
1598 PatchDataLayer &mpdat = mpdats.get(p.id_patch);
1599 auto &field = mpdat.get_field<T>(field_idx);
1600 field_refs.add_obj(p.id_patch, std::ref(field));
1601 });
1602 refs.set_refs(field_refs);
1603}
1604
// Overload that maps references from a `shamrock::ComputeField<T>`.
// For every non-empty patch of `sched`, a reference to `field_data.get_field(p.id_patch)`
// is collected, and the result is stored with `refs.set_refs`.
// NOTE(review): the parameter line declaring `refs` (listing line 1609) and the declaration
// of `field_refs` (listing line 1614) are absent from this rendering.
1605template<class T>
1606void map_field_refs_ext(
1607 PatchScheduler &sched,
1608 shamrock::ComputeField<T> &field_data,
1610
1611 using namespace shamrock::solvergraph;
1612 using namespace shamrock::patch;
1613
    // The scheduler is used only to enumerate non-empty patch ids; `pdat` itself is unused.
1615 sched.for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
1616 auto &field = field_data.get_field(p.id_patch);
1617 field_refs.add_obj(p.id_patch, std::ref(field));
1618 });
1619 refs.set_refs(field_refs);
1620}
1621
1622template<class Tvec, template<class> class Kern>
1624
1625 // has to be first since there is a barrier that may mess the other timers
1626 shamsys::SystemMetrics system_metrics_start = shamsys::get_system_metrics();
1627
1629 f64 mpi_timer_start = shamcomm::mpi::get_timer("total");
1630
1631 for (auto &callbacks : timestep_callbacks) {
1632 if (callbacks.step_begin_callback) {
1633 shambase::get_check_ref(callbacks.step_begin_callback)();
1634 }
1635 }
1636
1637 Tscal t_current = solver_config.get_time();
1638 Tscal dt = solver_config.get_dt_sph();
1639
1640 StackEntry stack_loc{};
1641
1642 if (shamcomm::world_rank() == 0) {
1644 shambase::format("---------------- t = {}, dt = {} ----------------", t_current, dt));
1645 }
1646
1647 shambase::Timer tstep;
1648 tstep.start();
1649
1650 // if(shamcomm::world_rank() == 0) std::cout << scheduler().dump_status() << std::endl;
1651 modules::ComputeLoadBalanceValue<Tvec, Kern>(context, solver_config, storage)
1652 .update_load_balancing();
1653 scheduler().scheduler_step(true, true);
1654 modules::ComputeLoadBalanceValue<Tvec, Kern>(context, solver_config, storage)
1655 .update_load_balancing();
1656 // if(shamcomm::world_rank() == 0) std::cout << scheduler().dump_status() << std::endl;
1657 scheduler().scheduler_step(false, false);
1658 // if(shamcomm::world_rank() == 0) std::cout << scheduler().dump_status() << std::endl;
1659
1661
1662 using namespace shamrock;
1663 using namespace shamrock::patch;
1664
1665 bool has_B_field = solver_config.has_field_B_on_rho();
1666 bool has_psi_field = solver_config.has_field_psi_on_ch();
1667 bool has_epsilon_field = solver_config.dust_config.has_epsilon_field();
1668 bool has_deltav_field = solver_config.dust_config.has_deltav_field();
1669 bool has_s_j_field = solver_config.dust_config.has_s_j_field();
1670
1671 PatchDataLayerLayout &pdl = scheduler().pdl_old();
1672
1673 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
1674 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
1675 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
1676 const u32 iuint = pdl.get_field_idx<Tscal>("uint");
1677 const u32 iduint = pdl.get_field_idx<Tscal>("duint");
1678 const u32 ihpart = pdl.get_field_idx<Tscal>("hpart");
1679 const u32 iB_on_rho = (has_B_field) ? pdl.get_field_idx<Tvec>("B/rho") : 0;
1680 const u32 idB_on_rho = (has_B_field) ? pdl.get_field_idx<Tvec>("dB/rho") : 0;
1681 const u32 ipsi_on_ch = (has_psi_field) ? pdl.get_field_idx<Tscal>("psi/ch") : 0;
1682 const u32 idpsi_on_ch = (has_psi_field) ? pdl.get_field_idx<Tscal>("dpsi/ch") : 0;
1683 const u32 iepsilon = (has_epsilon_field) ? pdl.get_field_idx<Tscal>("epsilon") : 0;
1684 const u32 idtepsilon = (has_epsilon_field) ? pdl.get_field_idx<Tscal>("dtepsilon") : 0;
1685 const u32 is_j = (has_s_j_field) ? pdl.get_field_idx<Tscal>("s_j") : 0;
1686 const u32 ids_j_dt = (has_s_j_field) ? pdl.get_field_idx<Tscal>("ds_j_dt") : 0;
1687 const u32 ideltav = (has_deltav_field) ? pdl.get_field_idx<Tvec>("deltav") : 0;
1688 const u32 idtdeltav = (has_deltav_field) ? pdl.get_field_idx<Tvec>("dtdeltav") : 0;
1689
1690 shamrock::SchedulerUtility utility(scheduler());
1691
1692 modules::SinkParticlesUpdate<Tvec, Kern> sink_update(context, solver_config, storage);
1693 modules::ExternalForces<Tvec, Kern> ext_forces(context, solver_config, storage);
1694
1695 sink_update.accrete_particles(dt);
1696 ext_forces.point_mass_accrete_particles();
1697
1698 sink_update.predictor_step(dt);
1699
1700 {
1701 // beginning of SolverGraph migration
1702
1703 using namespace shamrock::solvergraph;
1704
1705 SolverGraph &solver_graph = storage.solver_graph;
1706
1707 // change the graph inputs
1708 {
1709 solver_graph.get_edge_ref<IDataEdge<Tscal>>("dt").data = dt;
1710 solver_graph.get_edge_ref<IDataEdge<Tscal>>("dt_half").data = dt / 2.0;
1711 }
1712
1714 // Solver evaluation
1716
1717 shambase::get_check_ref(storage.solver_sequence).evaluate();
1718 }
1719
1720 sink_update.compute_ext_forces();
1721
1722 ext_forces.compute_ext_forces_indep_v();
1723
1724 gen_serial_patch_tree();
1725
1726 apply_position_boundary(t_current + dt);
1727
1728 u64 Npart_all = scheduler().get_total_obj_count();
1729
1730 if (solver_config.enable_particle_reordering
1731 && solve_logs.step_count % solver_config.particle_reordering_step_freq == 0) {
1732 logger::info_ln("SPH", "Reordering particles at step ", solve_logs.step_count);
1733 modules::ParticleReordering<Tvec, u_morton, Kern>(context, solver_config, storage)
1735 }
1736
1737 {
1738 // update part counts and spans since particles have been moved and thus
1739 // new patch can be non-empty/empty
1740 using namespace shamrock::solvergraph;
1741 SolverGraph &solver_graph = storage.solver_graph;
1742 solver_graph.get_node_ref_base("attach fields to scheduler").evaluate();
1743 }
1744
1745 sph_prestep(t_current, dt);
1746
1748
1749 // Here we will add self grav to the external forces indep of vel (this will be moved into a
1750 // separate module later)
1751 if (solver_config.self_grav_config.is_sg_on()) {
1752
1753 auto constant_G = shamrock::solvergraph::IDataEdge<Tscal>::make_shared("", "");
1754
1757 constant_G.data = solver_config.get_constant_G();
1758 });
1759
1760 set_constant_G.set_edges(constant_G);
1761
1762 auto field_xyz = shamrock::solvergraph::FieldRefs<Tvec>::make_shared("", "");
1763
1765 [&](shamrock::solvergraph::FieldRefs<Tvec> &field_xyz_edge) {
1767 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
1768 auto &field = pdat.get_field<Tvec>(ixyz);
1769 field_xyz_refs.add_obj(p.id_patch, std::ref(field));
1770 });
1771 field_xyz_edge.set_refs(field_xyz_refs);
1772 });
1773 set_field_xyz.set_edges(field_xyz);
1774
1775 const u32 iaxyz_ext = pdl.get_field_idx<Tvec>("axyz_ext");
1776
1777 auto field_axyz_ext = shamrock::solvergraph::FieldRefs<Tvec>::make_shared("", "");
1778
1780 set_field_axyz_ext([&](shamrock::solvergraph::FieldRefs<Tvec> &field_axyz_ext_edge) {
1781 shamrock::solvergraph::DDPatchDataFieldRef<Tvec> field_axyz_ext_refs = {};
1782 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
1783 auto &field = pdat.get_field<Tvec>(iaxyz_ext);
1784 field_axyz_ext_refs.add_obj(p.id_patch, std::ref(field));
1785 });
1786 field_axyz_ext_edge.set_refs(field_axyz_ext_refs);
1787 });
1788 set_field_axyz_ext.set_edges(field_axyz_ext);
1789
1790 auto sizes = shamrock::solvergraph::Indexes<u32>::make_shared("", "");
1791
1794 sizes.indexes = {};
1795 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
1796 sizes.indexes.add_obj(p.id_patch, pdat.get_obj_cnt());
1797 });
1798 });
1799 set_sizes.set_edges(sizes);
1800
1801 auto gpart_mass = shamrock::solvergraph::IDataEdge<Tscal>::make_shared("", "");
1802
1805 gpart_mass.data = solver_config.gpart_mass;
1806 });
1807
1808 set_gpart_mass.set_edges(gpart_mass);
1809
1810 set_gpart_mass.evaluate();
1811 set_constant_G.evaluate();
1812 set_field_xyz.evaluate();
1813 set_field_axyz_ext.evaluate();
1814 set_sizes.evaluate();
1815
1816 Tscal eps_grav = shambase::get_check_ref(
1817 std::get_if<SelfGravConfig::SofteningPlummer>(
1818 &solver_config.self_grav_config.softening_mode))
1819 .epsilon;
1820
1821 if (solver_config.self_grav_config.is_none()) {
1822 // do nothing
1823 } else if (solver_config.self_grav_config.is_direct()) {
1824
1826 std::get_if<SelfGravConfig::Direct>(&solver_config.self_grav_config.config));
1827
1828 modules::SGDirectPlummer<Tvec> self_gravity_direct_node(
1829 eps_grav, direct_config.reference_mode);
1830 self_gravity_direct_node.set_edges(
1831 sizes, gpart_mass, constant_G, field_xyz, field_axyz_ext);
1832 self_gravity_direct_node.evaluate();
1833
1834 } else if (solver_config.self_grav_config.is_mm()) {
1835
1837 std::get_if<SelfGravConfig::MM>(&solver_config.self_grav_config.config));
1838
1839 auto run_sg_mm = [&](auto mm_order_tag) {
1840 constexpr u32 order = decltype(mm_order_tag)::value;
1841 modules::SGMMPlummer<Tvec, order> self_gravity_mm_node(
1842 eps_grav, mm_config.opening_angle, mm_config.reduction_level);
1843 self_gravity_mm_node.set_edges(
1844 sizes, gpart_mass, constant_G, field_xyz, field_axyz_ext);
1845 self_gravity_mm_node.evaluate();
1846 };
1847
1848 switch (mm_config.order) {
1849 case 1 : run_sg_mm(std::integral_constant<u32, 1>{}); break;
1850 case 2 : run_sg_mm(std::integral_constant<u32, 2>{}); break;
1851 case 3 : run_sg_mm(std::integral_constant<u32, 3>{}); break;
1852 case 4 : run_sg_mm(std::integral_constant<u32, 4>{}); break;
1853 case 5 : run_sg_mm(std::integral_constant<u32, 5>{}); break;
1855 }
1856
1857 } else if (solver_config.self_grav_config.is_fmm()) {
1858
1860 std::get_if<SelfGravConfig::FMM>(&solver_config.self_grav_config.config));
1861
1862 auto run_sg_fmm = [&](auto fmm_order_tag) {
1863 constexpr u32 order = decltype(fmm_order_tag)::value;
1864 modules::SGFMMPlummer<Tvec, order> self_gravity_mm_node(
1865 eps_grav, fmm_config.opening_angle, fmm_config.reduction_level);
1866 self_gravity_mm_node.set_edges(
1867 sizes, gpart_mass, constant_G, field_xyz, field_axyz_ext);
1868 self_gravity_mm_node.evaluate();
1869 };
1870
1871 switch (fmm_config.order) {
1872 case 1 : run_sg_fmm(std::integral_constant<u32, 1>{}); break;
1873 case 2 : run_sg_fmm(std::integral_constant<u32, 2>{}); break;
1874 case 3 : run_sg_fmm(std::integral_constant<u32, 3>{}); break;
1875 case 4 : run_sg_fmm(std::integral_constant<u32, 4>{}); break;
1876 case 5 : run_sg_fmm(std::integral_constant<u32, 5>{}); break;
1878 }
1879
1880 } else if (solver_config.self_grav_config.is_sfmm()) {
1881
1883 std::get_if<SelfGravConfig::SFMM>(&solver_config.self_grav_config.config));
1884
1885 auto run_sg_sfmm = [&](auto sfmm_order_tag) {
1886 constexpr u32 order = decltype(sfmm_order_tag)::value;
1887 modules::SGSFMMPlummer<Tvec, order> self_gravity_mm_node(
1888 eps_grav,
1889 sfmm_config.opening_angle,
1890 sfmm_config.leaf_lowering,
1891 sfmm_config.reduction_level);
1892 self_gravity_mm_node.set_edges(
1893 sizes, gpart_mass, constant_G, field_xyz, field_axyz_ext);
1894 self_gravity_mm_node.evaluate();
1895 };
1896
1897 switch (sfmm_config.order) {
1898 case 1 : run_sg_sfmm(std::integral_constant<u32, 1>{}); break;
1899 case 2 : run_sg_sfmm(std::integral_constant<u32, 2>{}); break;
1900 case 3 : run_sg_sfmm(std::integral_constant<u32, 3>{}); break;
1901 case 4 : run_sg_sfmm(std::integral_constant<u32, 4>{}); break;
1902 case 5 : run_sg_sfmm(std::integral_constant<u32, 5>{}); break;
1904 }
1905
1906 } else {
1908 "Self gravity config not supported, current state is : \n"
1909 + nlohmann::json{solver_config.self_grav_config}.dump(4));
1910 }
1911 }
1912
1913 sph::BasicSPHGhostHandler<Tvec> &ghost_handle = storage.ghost_handler.get();
1914 auto &merged_xyzh = storage.merged_xyzh.get();
1915 shambase::DistributedData<RTree> &trees = storage.merged_pos_trees.get();
1916 // ComputeField<Tscal> &omega = storage.omega.get();
1917
1919 = shambase::get_check_ref(storage.ghost_layout.get());
1920 u32 ihpart_interf = ghost_layout.get_field_idx<Tscal>("hpart");
1921 u32 iuint_interf = ghost_layout.get_field_idx<Tscal>("uint");
1922 u32 ivxyz_interf = ghost_layout.get_field_idx<Tvec>("vxyz");
1923 u32 iomega_interf = ghost_layout.get_field_idx<Tscal>("omega");
1924 u32 iB_on_rho_interf = (has_B_field) ? ghost_layout.get_field_idx<Tvec>("B/rho") : 0;
1925 u32 ipsi_on_rho_interf = (has_psi_field) ? ghost_layout.get_field_idx<Tscal>("psi/ch") : 0;
1926
1927 using RTreeField = RadixTreeField<Tscal>;
1929
1930 Tscal next_cfl = 0;
1931
1932 u32 corrector_iter_cnt = 0;
1933 bool need_rerun_corrector = false;
1934 do {
1935
1938
1939 if (corrector_iter_cnt == 50) {
1941 "the corrector has made over 50 loops, either their is a bug, either you are using "
1942 "a dt that is too large");
1943 }
1944
1945 // communicate fields
1947
1948 if (solver_config.has_field_alphaAV()) {
1949
1950 std::shared_ptr<shamrock::solvergraph::PatchDataLayerRefs> patchdatas
1951 = std::make_shared<shamrock::solvergraph::PatchDataLayerRefs>(
1952 "patchdata_layer_ref", "patchdata_layer_ref");
1953
1954 auto node_set_edge = scheduler().get_node_set_edge_patchdata_layer_refs();
1955 node_set_edge->set_edges(patchdatas);
1956 node_set_edge->evaluate();
1957
1959 scheduler().get_layout_ptr_old(), "alpha_AV");
1960 node_copy.set_edges(patchdatas, storage.alpha_av_updated);
1961 node_copy.evaluate();
1962 }
1963
1964 if (solver_config.has_field_dtdivv()) {
1965
1966 if (solver_config.combined_dtdiv_divcurlv_compute) {
1967 if (solver_config.has_field_dtdivv()) {
1968 sph::modules::DiffOperatorDtDivv<Tvec, Kern>(context, solver_config, storage)
1969 .update_dtdivv(true);
1970 }
1971 } else {
1972
1973 if (solver_config.has_field_divv()) {
1974 sph::modules::DiffOperators<Tvec, Kern>(context, solver_config, storage)
1975 .update_divv();
1976 }
1977
1978 if (solver_config.has_field_curlv()) {
1979 sph::modules::DiffOperators<Tvec, Kern>(context, solver_config, storage)
1980 .update_curlv();
1981 }
1982
1983 if (solver_config.has_field_dtdivv()) {
1984 sph::modules::DiffOperatorDtDivv<Tvec, Kern>(context, solver_config, storage)
1985 .update_dtdivv(false);
1986 }
1987 }
1988
1989 } else {
1990 if (solver_config.has_field_divv()) {
1991 sph::modules::DiffOperators<Tvec, Kern>(context, solver_config, storage)
1992 .update_divv();
1993 }
1994
1995 if (solver_config.has_field_curlv()) {
1996 sph::modules::DiffOperators<Tvec, Kern>(context, solver_config, storage)
1997 .update_curlv();
1998 }
1999 }
2000
2001 // if (solver_config.has_field_divB()) {
2002 // sph::modules::DiffOperatorsB<Tvec, Kern>(context, solver_config, storage)
2003 // .update_divB();
2004 // }
2005
2006 // if (solver_config.has_field_curlB()) {
2007 // sph::modules::DiffOperatorsB<Tvec, Kern>(context, solver_config, storage)
2008 // .update_curlB();
2009 // }
2011
2012 if (solver_config.has_field_alphaAV()) {
2013
2015 = shambase::get_check_ref(storage.alpha_av_updated);
2016
2017 using InterfaceBuildInfos =
2019
2020 shambase::Timer time_interf;
2021 time_interf.start();
2022
2023 auto field_interf = ghost_handle.template build_interface_native<PatchDataField<Tscal>>(
2024 storage.ghost_patch_cache.get(),
2025 [&](u64 sender,
2026 u64 /*receiver*/,
2027 InterfaceBuildInfos binfo,
2028 sham::DeviceBuffer<u32> &buf_idx,
2029 u32 cnt) -> PatchDataField<Tscal> {
2030 PatchDataField<Tscal> &sender_field = comp_field_send.get_field(sender);
2031
2032 return sender_field.make_new_from_subset(buf_idx, cnt);
2033 });
2034
2036 = ghost_handle.communicate_pdatfield(
2037 std::move(field_interf), 1, storage.exchange_gz_alpha);
2038
2040 = ghost_handle.template merge_native<PatchDataField<Tscal>, PatchDataField<Tscal>>(
2041 std::move(interf_pdat),
2043 PatchDataField<Tscal> &receiver_field
2044 = comp_field_send.get_field(p.id_patch);
2045 return receiver_field.duplicate();
2046 },
2047 [](PatchDataField<Tscal> &mpdat, PatchDataField<Tscal> &pdat_interf) {
2048 mpdat.insert(pdat_interf);
2049 });
2050
2051 time_interf.stop();
2052 storage.timings_details.interface += time_interf.elapsed_sec();
2053
2054 storage.alpha_av_ghost.set(std::move(merged_field));
2055 }
2056
2057 // compute pressure
2059
2060 constexpr bool debug_interfaces = false;
2061 if constexpr (debug_interfaces) {
2062
2063 if (solver_config.do_debug_dump) {
2064
2066 = storage.merged_patchdata_ghost.get();
2067
2068 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
2069 MergedPatchData &merged_patch = mpdat.get(cur_p.id_patch);
2070 PatchDataLayer &mpdat = merged_patch.pdat;
2071
2072 sycl::buffer<Tvec> &buf_xyz = shambase::get_check_ref(
2073 merged_xyzh.get(cur_p.id_patch).field_pos.get_buf());
2074 sycl::buffer<Tvec> &buf_vxyz = mpdat.get_field_buf_ref<Tvec>(ivxyz_interf);
2075 sycl::buffer<Tscal> &buf_hpart = mpdat.get_field_buf_ref<Tscal>(ihpart_interf);
2076
2077 u32 total_elements = shambase::get_check_ref(storage.part_counts_with_ghost)
2078 .indexes.get(cur_p.id_patch);
2079 SHAM_ASSERT(merged_patch.total_elements == total_elements);
2080
2082 total_elements,
2083 solver_config.gpart_mass,
2084
2085 buf_xyz,
2086 buf_hpart,
2087 buf_vxyz};
2088
2089 make_interface_debug_phantom_dump(info).gen_file().write_to_file(
2090 solver_config.debug_dump_filename);
2091 logger::raw_ln("writing : ", solver_config.debug_dump_filename);
2092 });
2093 }
2094 }
2095
2096 // compute force
2097 shamlog_debug_ln("sph::BasicGas", "compute force");
2098
2099 // save old acceleration
2101
2102 update_derivs(dt);
2103
2104 bool has_luminosity = solver_config.compute_luminosity;
2105
2106 if (has_luminosity) {
2107 const u32 iluminosity = pdl.get_field_idx<Tscal>("luminosity");
2108
2109 shambase::get_check_ref(storage.hpart_with_ghosts)
2110 .set_refs(storage.merged_xyzh.get()
2111 .template map<std::reference_wrapper<PatchDataField<Tscal>>>(
2112 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
2113 return std::ref(mpdat.get_field<Tscal>(
2114 1)); // hpart is at index 1 in merged_xyzh
2115 }));
2116
2117 auto uint_with_ghost = shamrock::solvergraph::FieldRefs<Tscal>::make_shared("", "");
2118
2119 shambase::get_check_ref(storage.hpart_with_ghosts)
2120 .set_refs(storage.merged_xyzh.get()
2121 .template map<std::reference_wrapper<PatchDataField<Tscal>>>(
2122 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
2123 return std::ref(mpdat.get_field<Tscal>(1));
2124 }));
2125
2127 set_uint_with_ghost_refs(
2128 [&](shamrock::solvergraph::FieldRefs<Tscal> &field_uint_with_ghost_edge) {
2130 = storage.merged_patchdata_ghost.get();
2131
2132 shamrock::solvergraph::DDPatchDataFieldRef<Tscal> field_uint_with_ghost_refs
2133 = {};
2134
2135 scheduler().for_each_patchdata_nonempty(
2136 [&](const Patch p, PatchDataLayer &pdat) {
2137 PatchDataLayer &mpdat = mpdats.get(p.id_patch);
2138
2139 auto &field = mpdat.get_field<Tscal>(iuint_interf);
2140 field_uint_with_ghost_refs.add_obj(p.id_patch, std::ref(field));
2141 });
2142
2143 field_uint_with_ghost_edge.set_refs(field_uint_with_ghost_refs);
2144 });
2145
2146 set_uint_with_ghost_refs.set_edges(uint_with_ghost);
2147
2148 auto luminosity = shamrock::solvergraph::FieldRefs<Tscal>::make_shared("", "");
2149
2151 set_luminosity_refs(
2152 [&](shamrock::solvergraph::FieldRefs<Tscal> &field_luminosity_edge) {
2154 = storage.merged_patchdata_ghost.get();
2155
2157 = {};
2158
2159 scheduler().for_each_patchdata_nonempty(
2160 [&](const Patch p, PatchDataLayer &pdat) {
2161 auto &field = pdat.get_field<Tscal>(iluminosity);
2162 field_luminosity_refs.add_obj(p.id_patch, std::ref(field));
2163 });
2164 field_luminosity_edge.set_refs(field_luminosity_refs);
2165 });
2166
2167 set_luminosity_refs.set_edges(luminosity);
2168
2169 set_uint_with_ghost_refs.evaluate();
2170 set_luminosity_refs.evaluate();
2171
2172 Tscal alpha_u = solver_config.artif_viscosity.get_alpha_u().value();
2173
2175 solver_config.gpart_mass, alpha_u};
2176
2177 compute_luminosity.set_edges(
2178 storage.part_counts,
2179 storage.neigh_cache,
2180 storage.positions_with_ghosts,
2181 storage.hpart_with_ghosts,
2182 storage.omega,
2183 uint_with_ghost,
2184 storage.pressure,
2185 luminosity);
2186
2187 compute_luminosity.evaluate();
2188 }
2189
2190 modules::ConservativeCheck<Tvec, Kern> cv_check(context, solver_config, storage);
2191 cv_check.check_conservation();
2192
2193 ComputeField<Tscal> vepsilon_v_sq
2194 = utility.make_compute_field<Tscal>("vmean epsilon_v^2", 1);
2195 ComputeField<Tscal> uepsilon_u_sq
2196 = utility.make_compute_field<Tscal>("umean epsilon_u^2", 1);
2197
2198 // corrector
2199 shamlog_debug_ln("sph::BasicGas", "leapfrog corrector");
2200 utility.fields_leapfrog_corrector<Tvec>(
2201 ivxyz, iaxyz, storage.old_axyz.get(), vepsilon_v_sq, dt / 2);
2202 utility.fields_leapfrog_corrector<Tscal>(
2203 iuint, iduint, storage.old_duint.get(), uepsilon_u_sq, dt / 2);
2204
2205 if (solver_config.has_field_B_on_rho()) {
2206 ComputeField<Tscal> BOR_epsilon_BOR_sq
2207 = utility.make_compute_field<Tscal>("B/rho epsilon_B/rho^2", 1);
2208 utility.fields_leapfrog_corrector<Tvec>(
2209 iB_on_rho, idB_on_rho, storage.old_dB_on_rho.get(), BOR_epsilon_BOR_sq, dt / 2);
2210 }
2211 if (solver_config.has_field_B_on_rho()) {
2212 ComputeField<Tscal> POC_epsilon_POC_sq
2213 = utility.make_compute_field<Tscal>("psi/ch epsilon_psi/ch^2", 1);
2214 utility.fields_leapfrog_corrector<Tscal>(
2215 ipsi_on_ch, idpsi_on_ch, storage.old_dpsi_on_ch.get(), POC_epsilon_POC_sq, dt / 2);
2216 }
2217
2218 if (solver_config.dust_config.has_epsilon_field()) {
2219 ComputeField<Tscal> epsilon_epsilon_sq
2220 = utility.make_compute_field<Tscal>("epsilon epsilon^2", 1);
2221 utility.fields_leapfrog_corrector<Tscal>(
2222 iepsilon, idtepsilon, storage.old_dtepsilon.get(), epsilon_epsilon_sq, dt / 2);
2223 }
2224
2225 if (solver_config.dust_config.has_deltav_field()) {
2226 ComputeField<Tscal> epsilon_deltav_sq
2227 = utility.make_compute_field<Tscal>("deltav deltav^2", 1);
2228 utility.fields_leapfrog_corrector<Tvec>(
2229 ideltav, idtdeltav, storage.old_dtdeltav.get(), epsilon_deltav_sq, dt / 2);
2230 }
2231
2232 if (solver_config.dust_config.has_s_j_field()) {
2233 ComputeField<Tscal> s_j_s_j_sq = utility.make_compute_field<Tscal>(
2234 "s_j s_j^2", solver_config.dust_config.get_dust_nvar());
2235 utility.fields_leapfrog_corrector<Tscal>(
2236 is_j, ids_j_dt, storage.old_ds_j_dt.get(), s_j_s_j_sq, dt / 2);
2237 }
2238
2239 storage.old_axyz.reset();
2240 storage.old_duint.reset();
2241 if (solver_config.has_field_B_on_rho()) {
2242 storage.old_dB_on_rho.reset();
2243 }
2244 if (solver_config.has_field_B_on_rho()) {
2245 storage.old_dpsi_on_ch.reset();
2246 }
2247
2248 if (solver_config.dust_config.has_epsilon_field()) {
2249 storage.old_dtepsilon.reset();
2250 }
2251
2252 if (solver_config.dust_config.has_deltav_field()) {
2253 storage.old_dtdeltav.reset();
2254 }
2255
2256 if (solver_config.dust_config.has_s_j_field()) {
2257 storage.old_ds_j_dt.reset();
2258 }
2259
2260 Tscal rank_veps_v = sycl::sqrt(vepsilon_v_sq.compute_rank_max());
2262 // compute means //////////////////////////
2264
2265 Tscal sum_vsq = utility.compute_rank_dot_sum<Tvec>(ivxyz);
2266
2267 Tscal vmean_sq = shamalgs::collective::allreduce_sum(sum_vsq) / Tscal(Npart_all);
2268
2269 Tscal vmean = sycl::sqrt(vmean_sq);
2270
2271 Tscal rank_eps_v = rank_veps_v / vmean;
2272
2273 if (vmean <= 0) {
2274 rank_eps_v = 0;
2275 }
2276
2277 Tscal eps_v = shamalgs::collective::allreduce_max(rank_eps_v);
2278
2279 shamlog_debug_ln("BasicGas", "epsilon v :", eps_v);
2280
2281 if (eps_v > 1e-2) {
2282 if (shamcomm::world_rank() == 0) {
2284 "BasicGasSPH",
2285 shambase::format(
2286 "the corrector tolerance are broken the step will "
2287 "be re rerunned\n eps_v = {}",
2288 eps_v));
2289 }
2290 need_rerun_corrector = true;
2291 solver_config.time_state.cfl_multiplier /= 2;
2292
2293 // logger::info_ln("rerun corrector ...");
2294 } else {
2295 need_rerun_corrector = false;
2296 }
2297
2298 if (!need_rerun_corrector) {
2299
2300 sink_update.corrector_step(dt);
2301
2302 // write back alpha av field
2303 if (solver_config.has_field_alphaAV()) {
2304
2305 const u32 ialpha_AV = pdl.get_field_idx<Tscal>("alpha_AV");
2306 shamrock::solvergraph::Field<Tscal> &alpha_av_updated
2307 = shambase::get_check_ref(storage.alpha_av_updated);
2308
2309 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
2310 sham::DeviceBuffer<Tscal> &buf_alpha_av
2311 = pdat.get_field<Tscal>(ialpha_AV).get_buf();
2312 sham::DeviceBuffer<Tscal> &buf_alpha_av_updated
2313 = alpha_av_updated.get_field(cur_p.id_patch).get_buf();
2314
2315 auto &q = shamsys::instance::get_compute_scheduler().get_queue();
2316 sham::EventList depends_list;
2317
2318 auto alpha_av = buf_alpha_av.get_write_access(depends_list);
2319 auto alpha_av_updated = buf_alpha_av_updated.get_read_access(depends_list);
2320
2321 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
2322 shambase::parallel_for(
2323 cgh, pdat.get_obj_cnt(), "write back alpha_av", [=](i32 id_a) {
2324 alpha_av[id_a] = alpha_av_updated[id_a];
2325 });
2326 });
2327
2328 buf_alpha_av.complete_event_state(e);
2329 buf_alpha_av_updated.complete_event_state(e);
2330 });
2331 }
2332
2333 shamlog_debug_ln("BasicGas", "computing next CFL");
2334
2335 // Update element counts
2336 shambase::get_check_ref(storage.part_counts).indexes
2337 = storage.merged_xyzh.get().template map<u32>(
2338 [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) {
2339 return scheduler().patch_data.get_pdat(id).get_obj_cnt();
2340 });
2341
2342 std::shared_ptr<shamrock::solvergraph::Field<Tscal>> vsig_max_dt
2343 = std::make_shared<shamrock::solvergraph::Field<Tscal>>(
2344 1, "vsig_a", "v_{\\rm sig}");
2345 vsig_max_dt->ensure_sizes(shambase::get_check_ref(storage.part_counts).indexes);
2346
2347 std::shared_ptr<shamrock::solvergraph::Field<Tscal>> vclean_dt;
2348 if (has_psi_field) {
2349 vclean_dt = std::make_shared<shamrock::solvergraph::Field<Tscal>>(
2350 1, "vclean_a", "v_{\\rm clean}");
2351 vclean_dt->ensure_sizes(shambase::get_check_ref(storage.part_counts).indexes);
2352 }
2353
2355 = storage.merged_patchdata_ghost.get();
2356
2357 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
2358 PatchDataLayer &mpdat = mpdats.get(cur_p.id_patch);
2359
2361 = merged_xyzh.get(cur_p.id_patch).template get_field_buf_ref<Tvec>(0);
2362 sham::DeviceBuffer<Tvec> &buf_vxyz = mpdat.get_field_buf_ref<Tvec>(ivxyz_interf);
2363 sham::DeviceBuffer<Tscal> &buf_hpart
2364 = mpdat.get_field_buf_ref<Tscal>(ihpart_interf);
2365 sham::DeviceBuffer<Tscal> &buf_uint = mpdat.get_field_buf_ref<Tscal>(iuint_interf);
2366 sham::DeviceBuffer<Tscal> &buf_pressure
2367 = shambase::get_check_ref(storage.pressure).get_field(cur_p.id_patch).get_buf();
2368 sham::DeviceBuffer<Tscal> &cs_buf = shambase::get_check_ref(storage.soundspeed)
2369 .get_field(cur_p.id_patch)
2370 .get_buf();
2371
2372 sham::DeviceBuffer<Tscal> &vsig_buf = vsig_max_dt->get_buf(cur_p.id_patch);
2373
2374 sycl::range range_npart{pdat.get_obj_cnt()};
2375
2376 tree::ObjectCache &pcache
2377 = shambase::get_check_ref(storage.neigh_cache).get_cache(cur_p.id_patch);
2378
2380
2381 {
2382
2383 auto &q = shamsys::instance::get_compute_scheduler().get_queue();
2384 sham::EventList depends_list;
2385
2386 auto xyz = buf_xyz.get_read_access(depends_list);
2387 auto vxyz = buf_vxyz.get_read_access(depends_list);
2388 auto hpart = buf_hpart.get_read_access(depends_list);
2389 auto u = buf_uint.get_read_access(depends_list);
2390 auto pressure = buf_pressure.get_read_access(depends_list);
2391 auto cs = cs_buf.get_read_access(depends_list);
2392 auto vsig = vsig_buf.get_write_access(depends_list);
2393 auto particle_looper_ptrs = pcache.get_read_access(depends_list);
2394
2395 NamedStackEntry tmppp{"compute vsig"};
2396 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
2397 const Tscal pmass = solver_config.gpart_mass;
2398 const Tscal alpha_u = 1.0;
2399 const Tscal alpha_AV = 1.0;
2400 const Tscal beta_AV = 2.0;
2401
2402 tree::ObjectCacheIterator particle_looper(particle_looper_ptrs);
2403
2404 constexpr Tscal Rker2 = Kernel::Rkern * Kernel::Rkern;
2405
2406 shambase::parallel_for(
2407 cgh, pdat.get_obj_cnt(), "compute vsig", [=](i32 id_a) {
2408 using namespace shamrock::sph;
2409
2410 Tvec sum_axyz = {0, 0, 0};
2411 Tscal sum_du_a = 0;
2412 Tscal h_a = hpart[id_a];
2413
2414 Tvec xyz_a = xyz[id_a];
2415 Tvec vxyz_a = vxyz[id_a];
2416
2417 Tscal rho_a = rho_h(pmass, h_a, Kernel::hfactd);
2418 Tscal rho_a_sq = rho_a * rho_a;
2419 Tscal rho_a_inv = 1. / rho_a;
2420
2421 Tscal P_a = pressure[id_a];
2422
2423 const Tscal u_a = u[id_a];
2424
2425 Tscal cs_a = cs[id_a];
2426
2427 Tscal vsig_max = 0;
2428
2429 particle_looper.for_each_object(id_a, [&](u32 id_b) {
2430 // compute only omega_a
2431 Tvec dr = xyz_a - xyz[id_b];
2432 Tscal rab2 = sycl::dot(dr, dr);
2433 Tscal h_b = hpart[id_b];
2434
2435 if (rab2 > h_a * h_a * Rker2 && rab2 > h_b * h_b * Rker2) {
2436 return;
2437 }
2438
2439 Tscal rab = sycl::sqrt(rab2);
2440 Tvec vxyz_b = vxyz[id_b];
2441 Tvec v_ab = vxyz_a - vxyz_b;
2442 const Tscal u_b = u[id_b];
2443
2444 Tvec r_ab_unit = dr / rab;
2445
2446 if (rab < 1e-9) {
2447 r_ab_unit = {0, 0, 0};
2448 }
2449
2450 Tscal rho_b = rho_h(pmass, h_b, Kernel::hfactd);
2451 Tscal P_b = pressure[id_b];
2452 Tscal cs_b = cs[id_b];
2453 Tscal v_ab_r_ab = sycl::dot(v_ab, r_ab_unit);
2454 Tscal abs_v_ab_r_ab = sycl::fabs(v_ab_r_ab);
2455
2457 // internal energy update
2458 // scalar : f32 | vector : f32_3
2459 const Tscal alpha_a = alpha_AV;
2460 const Tscal alpha_b = alpha_AV;
2461
2462 Tscal vsig_a = alpha_a * cs_a + beta_AV * abs_v_ab_r_ab;
2463
2464 vsig_max = sycl::fmax(vsig_max, vsig_a);
2465 });
2466
2467 vsig[id_a] = vsig_max;
2468 });
2469 });
2470
2471 if (has_psi_field) {
2472 NamedStackEntry tmppp{"compute vclean"};
2473 Tscal const mu_0 = solver_config.get_constant_mu_0();
2474 sham::DeviceBuffer<Tscal> &vclean_buf = vclean_dt->get_buf(cur_p.id_patch);
2475
2476 Tvec *B_on_rho = mpdat.get_field_buf_ref<Tvec>(iB_on_rho_interf)
2477 .get_write_access(depends_list);
2478
2479 auto vclean = vclean_buf.get_write_access(depends_list);
2480
2481 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
2482 const Tscal pmass = solver_config.gpart_mass;
2483
2484 tree::ObjectCacheIterator particle_looper(particle_looper_ptrs);
2485
2486 constexpr Tscal Rker2 = Kernel::Rkern * Kernel::Rkern;
2487
2488 shambase::parallel_for(
2489 cgh, pdat.get_obj_cnt(), "compute vclean", [=](i32 id_a) {
2490 using namespace shamrock::sph;
2491
2492 Tscal h_a = hpart[id_a];
2493 Tscal rho_a = rho_h(pmass, h_a, Kernel::hfactd);
2494 const Tscal u_a = u[id_a];
2495 Tscal cs_a = cs[id_a];
2496 Tvec B_a = B_on_rho[id_a] * rho_a;
2497
2498 Tscal vclean_a = shamphys::MHD_physics<Tvec, Tscal>::v_shock(
2499 cs_a, B_a, rho_a, mu_0);
2500
2501 vclean[id_a] = vclean_a;
2502 });
2503 });
2504 mpdat.get_field_buf_ref<Tvec>(iB_on_rho_interf).complete_event_state(e);
2505 vclean_buf.complete_event_state(e);
2506 };
2507
2508 buf_xyz.complete_event_state(e);
2509 buf_vxyz.complete_event_state(e);
2510 buf_hpart.complete_event_state(e);
2511 buf_uint.complete_event_state(e);
2512 buf_pressure.complete_event_state(e);
2513 cs_buf.complete_event_state(e);
2514 vsig_buf.complete_event_state(e);
2515
2516 sham::EventList resulting_events;
2517 resulting_events.add_event(e);
2518 pcache.complete_event_state(resulting_events);
2519 }
2520 });
2521
2522 std::shared_ptr<shamrock::solvergraph::Field<Tscal>> cfl_dt
2523 = std::make_shared<shamrock::solvergraph::Field<Tscal>>(
2524 1, "cfl_dt", "\\Delta t_{cfl}");
2525 cfl_dt->ensure_sizes(shambase::get_check_ref(storage.part_counts).indexes);
2526
2527 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tvec>> axyz_refs
2528 = std::make_shared<shamrock::solvergraph::FieldRefs<Tvec>>("axyz", "\\mathbf{a}");
2529 std::shared_ptr<shamrock::solvergraph::FieldRefs<Tscal>> hpart_refs
2530 = std::make_shared<shamrock::solvergraph::FieldRefs<Tscal>>("hpart", "h");
2531
2532 map_field_refs(scheduler(), iaxyz, *axyz_refs);
2533 map_field_refs_ext(scheduler(), mpdats, ihpart_interf, *hpart_refs);
2534
2535 auto &q = shamsys::instance::get_compute_scheduler().get_queue();
2536
2537 auto reset_dt_part_field = [&]() {
2538 if (solver_config.should_save_dt_to_fields()) {
2539 const u32 idt_part = pdl.get_field_idx<Tscal>("dt_part");
2540 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
2541 sham::DeviceBuffer<Tscal> &buf_dt_part
2542 = pdat.get_field_buf_ref<Tscal>(idt_part);
2543 buf_dt_part.fill(shambase::get_infty<Tscal>());
2544 });
2545 }
2546 };
2547
2548 auto save_dt_min_to_dt_part = [&]() {
2549 if (solver_config.should_save_dt_to_fields()) {
2550 const u32 idt_part = pdl.get_field_idx<Tscal>("dt_part");
2551 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
2552 sham::DeviceBuffer<Tscal> &buf_dt_part
2553 = pdat.get_field_buf_ref<Tscal>(idt_part);
2554 sham::DeviceBuffer<Tscal> &buf_dt = cfl_dt->get_buf(cur_p.id_patch);
2555
2557 q,
2558 sham::MultiRef{buf_dt},
2559 sham::MultiRef{buf_dt_part},
2560 pdat.get_obj_cnt(),
2561 [](u32 id_a, const Tscal *dt, Tscal *dt_part) {
2562 dt_part[id_a] = sycl::min(dt_part[id_a], dt[id_a]);
2563 });
2564 });
2565 }
2566 };
2567
2568 // reset the cfl_dt field
2569 auto reset_cfl_dt = [&]() {
2570 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
2571 cfl_dt->get_buf(cur_p.id_patch).fill(shambase::get_infty<Tscal>());
2572 });
2573 };
2574
2575 Tscal C_cour
2576 = solver_config.cfl_config.cfl_cour * solver_config.time_state.cfl_multiplier;
2577 Tscal C_force
2578 = solver_config.cfl_config.cfl_force * solver_config.time_state.cfl_multiplier;
2579 Tscal eta_phi = solver_config.cfl_config.eta_sink;
2580
2581 std::shared_ptr<shamrock::solvergraph::ScalarEdge<Tscal>> C_cour_edge
2582 = std::make_shared<shamrock::solvergraph::ScalarEdge<Tscal>>("C_cour", "C_{cour}");
2583 C_cour_edge->value = C_cour;
2584 std::shared_ptr<shamrock::solvergraph::ScalarEdge<Tscal>> C_force_edge
2585 = std::make_shared<shamrock::solvergraph::ScalarEdge<Tscal>>(
2586 "C_force", "C_{force}");
2587 C_force_edge->value = C_force;
2588 std::shared_ptr<shamrock::solvergraph::ScalarEdge<Tscal>> eta_phi_edge
2589 = std::make_shared<shamrock::solvergraph::ScalarEdge<Tscal>>(
2590 "eta_phi", "\\eta_{\\phi}");
2591 eta_phi_edge->value = eta_phi;
2592
2593 std::shared_ptr<ComputeCFLCourant<Tscal>> compute_cfl_courant
2594 = std::make_shared<ComputeCFLCourant<Tscal>>();
2595 compute_cfl_courant->set_edges(
2596 storage.part_counts, C_cour_edge, hpart_refs, vsig_max_dt, cfl_dt);
2597
2598 std::shared_ptr<ComputeCFLForce<Tvec>> compute_cfl_force
2599 = std::make_shared<ComputeCFLForce<Tvec>>();
2600 compute_cfl_force->set_edges(
2601 storage.part_counts, C_force_edge, hpart_refs, axyz_refs, cfl_dt);
2602
2603 std::shared_ptr<ComputeCFLDivBCleaning<Tscal>> compute_cfl_divB_cleaning;
2604 if (has_psi_field) {
2605 compute_cfl_divB_cleaning = std::make_shared<ComputeCFLDivBCleaning<Tscal>>();
2606 compute_cfl_divB_cleaning->set_edges(
2607 storage.part_counts, C_cour_edge, hpart_refs, vclean_dt, cfl_dt);
2608 }
2609
2610 bool show_cfl_detail = solver_config.show_cfl_detail;
2611 std::vector<std::pair<std::string, Tscal>> cfl_detail;
2612
2613 auto save_cfl_detail = [&](const char *key) {
2614 if (show_cfl_detail) {
2615 save_dt_min_to_dt_part();
2616 cfl_detail.push_back(
2617 {std::string(key), cfl_dt->get_native().compute_rank_min()});
2618 reset_cfl_dt();
2619 }
2620 };
2621
2622 reset_dt_part_field();
2623 reset_cfl_dt();
2624
2625 compute_cfl_courant->evaluate();
2626 save_cfl_detail("courant");
2627
2628 compute_cfl_force->evaluate();
2629 save_cfl_detail("force");
2630
2631 if (has_psi_field) {
2632 compute_cfl_divB_cleaning->evaluate();
2633 save_cfl_detail("divB_cleaning");
2634 }
2635
2636 if (!show_cfl_detail) {
2637 save_dt_min_to_dt_part();
2638 cfl_detail.push_back({"all SPH", cfl_dt->get_native().compute_rank_min()});
2639 }
2640
2641 if (!storage.sinks.is_empty()) {
2642 // sink sink CFL
2643
2644 Tscal sink_sink_cfl = shambase::get_infty<Tscal>();
2645
2646 Tscal G = solver_config.get_constant_G();
2647
2648 std::vector<SinkParticle<Tvec>> &sink_parts = storage.sinks.get();
2649
2650 for (u32 i = 0; i < sink_parts.size(); i++) {
2651 SinkParticle<Tvec> &s_i = sink_parts[i];
2652 Tscal sink_sink_cfl_i = shambase::get_infty<Tscal>();
2653
2654 Tvec f_i = s_i.ext_acceleration;
2655
2656 Tscal grad_phi_i_sq = sham::dot(f_i, f_i); // m^2.s^-4
2657
2658 if (grad_phi_i_sq == 0) {
2659 continue;
2660 }
2661
2662 for (u32 j = 0; j < sink_parts.size(); j++) {
2663 SinkParticle<Tvec> &s_j = sink_parts[j];
2664
2665 if (i == j) {
2666 continue;
2667 }
2668
2669 Tvec rij = s_i.pos - s_j.pos;
2670 Tscal rij_scal = sycl::length(rij);
2671
2672 Tscal phi_ij = G * s_j.mass / rij_scal; // J / kg = m^2.s^-2
2673 Tscal term_ij = sham::abs(phi_ij) / grad_phi_i_sq; // s^2
2674 Tscal dt_ij = C_force * eta_phi * sycl::sqrt(term_ij); // s
2675
2676 sink_sink_cfl_i = sham::min(sink_sink_cfl_i, dt_ij);
2677 }
2678
2679 sink_sink_cfl = sham::min(sink_sink_cfl, sink_sink_cfl_i);
2680 }
2681
2682 cfl_detail.push_back({"sink_sink", sink_sink_cfl});
2683 }
2684
2685 Tscal rank_dt = shambase::get_infty<Tscal>();
2686 for (auto &[key, value] : cfl_detail) {
2687 rank_dt = sham::min(rank_dt, value);
2688 }
2689
2690 if (show_cfl_detail) {
2691 for (auto &[key, value] : cfl_detail) {
2692 value = shamalgs::collective::allreduce_min(value);
2693 }
2694
2695 if (shamcomm::world_rank() == 0) {
2696 shambase::table table(2);
2697 table.add_double_rule();
2698 table.add_data({"key", "value"}, shambase::table::center);
2699 table.add_double_rule();
2700 for (auto &[key, value] : cfl_detail) {
2701 table.add_data(
2702 {key, shambase::format("{:.2e}", value)}, shambase::table::right);
2703 }
2704 table.add_rule();
2705 logger::info_ln("sph::Model", "CFL detail :", table.render());
2706 }
2707 }
2708
2709 next_cfl = shamalgs::collective::allreduce_min(rank_dt);
2710
2711 if (shamcomm::world_rank() == 0) {
2713 "sph::Model",
2714 "cfl dt =",
2715 next_cfl,
2716 "cfl multiplier :",
2717 solver_config.time_state.cfl_multiplier);
2718 }
2719
2720 // this should not be needed ideally, but we need the pressure on the ghosts and
2721 // we don't want to communicate it as it can be recomputed from the other fields
2722 // hence we copy the soundspeed at the end of the step to a field in the patchdata
2723 if (solver_config.has_field_soundspeed()) {
2724
2725 const u32 isoundspeed = pdl.get_field_idx<Tscal>("soundspeed");
2726
2727 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
2728 sham::DeviceBuffer<Tscal> &buf_cs = pdat.get_field_buf_ref<Tscal>(isoundspeed);
2729 sham::DeviceBuffer<Tscal> &buf_cs_in
2730 = shambase::get_check_ref(storage.soundspeed)
2731 .get_field(cur_p.id_patch)
2732 .get_buf();
2733
2734 sycl::range range_npart{pdat.get_obj_cnt()};
2735
2737
2738 auto &q = shamsys::instance::get_compute_scheduler().get_queue();
2739 sham::EventList depends_list;
2740
2741 auto cs_in = buf_cs_in.get_read_access(depends_list);
2742 auto cs = buf_cs.get_write_access(depends_list);
2743
2744 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
2745 const Tscal pmass = solver_config.gpart_mass;
2746
2747 cgh.parallel_for(
2748 sycl::range<1>{pdat.get_obj_cnt()}, [=](sycl::item<1> item) {
2749 cs[item] = cs_in[item];
2750 });
2751 });
2752
2753 buf_cs_in.complete_event_state(e);
2754 buf_cs.complete_event_state(e);
2755 });
2756 }
2757
2758 } // if (!need_rerun_corrector) {
2759
2760 corrector_iter_cnt++;
2761
2762 if (solver_config.has_field_alphaAV()) {
2763 storage.alpha_av_ghost.reset();
2764 }
2765 } while (need_rerun_corrector);
2766
2767 reset_merge_ghosts_fields();
2768 reset_eos_fields();
2769
2770 // if delta too big jump to compute force
2771
2772 tstep.stop();
2773
2774 for (auto it = timestep_callbacks.rbegin(); it != timestep_callbacks.rend(); ++it) {
2775 if (it->step_end_callback) {
2776 shambase::get_check_ref(it->step_end_callback)();
2777 }
2778 }
2779
2780 f64 delta_mpi_timer = shamcomm::mpi::get_timer("total") - mpi_timer_start;
2782
2784 shamsys::SystemMetrics system_metrics_end = shamsys::get_system_metrics();
2785 shamsys::SystemMetrics system_metrics_delta = system_metrics_end - system_metrics_start;
2786
2787 f64 t_dev_alloc
2788 = (mem_perf_infos_end.time_alloc_device - mem_perf_infos_start.time_alloc_device)
2789 + (mem_perf_infos_end.time_free_device - mem_perf_infos_start.time_free_device);
2790 f64 t_host_alloc = (mem_perf_infos_end.time_alloc_host - mem_perf_infos_start.time_alloc_host)
2791 + (mem_perf_infos_end.time_free_host - mem_perf_infos_start.time_free_host);
2792
2793 u64 rank_count = scheduler().get_rank_count();
2794 f64 rate = f64(rank_count) / tstep.elapsed_sec();
2795
2796 u64 npatch = scheduler().patch_list.local.size();
2797
2798 // logger::info_ln("SPHSolver", "process rate : ", rate, "particle.s-1");
2799
2800 std::string log_step = report_perf_timestep(
2801 rate,
2802 rank_count,
2803 npatch,
2804 tstep.elapsed_sec(),
2805 delta_mpi_timer,
2806 t_dev_alloc,
2807 t_host_alloc,
2808 mem_perf_infos_end.max_allocated_byte_device,
2809 mem_perf_infos_end.max_allocated_byte_host,
2810 system_metrics_delta,
2811 shamsys::has_reporter());
2812
2813 if (shamcomm::world_rank() == 0) {
2814 logger::info_ln("sph::Model", log_step);
2816 "sph::Model", "estimated rate :", dt * (3600 / tstep.elapsed_sec()), "(tsim/hr)");
2817 }
2818
2819 solve_logs.register_log(
2820 {t_current, // f64 solver_t;
2821 dt, // f64 solver_dt;
2822 shamcomm::world_rank(), // i32 world_rank;
2823 rank_count, // u64 rank_count;
2824 rate, // f64 rate;
2825 tstep.elapsed_sec(), // f64 elapsed_sec;
2827 system_metrics_delta});
2828
2829 storage.timings_details.reset();
2830
2831 reset_serial_patch_tree();
2832 reset_ghost_handler();
2833
2834 shambase::get_check_ref(storage.part_counts).free_alloc();
2835 shambase::get_check_ref(storage.part_counts_with_ghost).free_alloc();
2836 shambase::get_check_ref(storage.positions_with_ghosts).free_alloc();
2837 shambase::get_check_ref(storage.hpart_with_ghosts).free_alloc();
2838 storage.merged_xyzh.reset();
2839 shambase::get_check_ref(storage.omega).free_alloc();
2840 clear_merged_pos_trees();
2841 clear_ghost_cache();
2842 reset_presteps_rint();
2843 reset_neighbors_cache();
2844
2845 shambase::get_check_ref(storage.neigh_cache).free_alloc();
2846
2847 solver_config.set_next_dt(next_cfl);
2848 solver_config.set_time(t_current + dt);
2849
2850 auto get_next_cfl_mult = [&]() {
2851 Tscal cfl_m = solver_config.time_state.cfl_multiplier;
2852 Tscal stiff = solver_config.cfl_config.cfl_multiplier_stiffness;
2853
2854 return (cfl_m * stiff + 1.) / (stiff + 1.);
2855 };
2856
2857 solver_config.time_state.cfl_multiplier = get_next_cfl_mult();
2858
2859 TimestepLog log;
2860 log.rank = shamcomm::world_rank();
2861 log.rate = rate;
2862 log.npart = rank_count;
2863 log.tcompute = tstep.elapsed_sec();
2864
2865 return log;
2866}
2867
2868using namespace shammath;
2869
2873
A module to compute and display statistics on neighbor counts for SPH particles.
Defines the CopyPatchDataFieldFromLayer class for copying fields between patch data layers.
Defines the DistributedBuffers class for managing distributed device buffers in a solver graph.
Implements a forward Euler integration step as a solver graph node.
Defines the GetFieldRefFromLayer class for extracting field references from patch data layers.
Defines the GetObjCntFromLayer class for extracting object counts from patch data layers.
Declares the GetParticlesOutsideSphere module for removing particles.
shambase::DistributedData< PatchDataFieldRef< T > > DDPatchDataFieldRef
Alias for a DistributedData of PatchDataFieldRefs.
Declares the IterateSmoothingLengthDensityNeighLim module for iterating smoothing length based on the...
Declares the IterateSmoothingLengthDensity module for iterating smoothing length based on the SPH den...
Declares the KillParticles module for removing particles.
Declares the LoopSmoothingLengthIter module for looping over the smoothing length iteration until con...
Field variant object to instantiate a variant on the patch types.
Header file describing a Node Instance.
Node that applies a custom function to modify connected edges.
Defines the PatchDataLayerRefs class for managing distributed references to patch data layers.
MPI scheduler.
Header file for the patch struct and related function.
Declare a class to register and retrieve nodes and edges from a unique container.
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
Shamrock assertion utility.
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
Definition assert.hpp:67
The MPI scheduler.
A buffer allocated in USM (Unified Shared Memory).
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
DeviceQueue & get_queue() const
Gets the DeviceQueue associated with the held allocation.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
void fill(T value, std::array< size_t, 2 > idx_range)
Fill a subpart of the buffer with a given value.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
void add_event(sycl::event e)
Add an event to the list of events.
Definition EventList.hpp:87
Container for objects shared between two distributed data elements.
void for_each(std::function< void(u64, u64, T &)> &&f)
Apply a function to all stored objects.
Represents a collection of objects distributed across patches identified by a u64 id.
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
T & get(u64 id)
Returns a reference to an object in the collection.
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:35
f64 elapsed_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
Definition time.hpp:87
void start()
Starts the timer.
Definition time.hpp:50
void stop()
Stops the timer and stores the elapsed time in nanoseconds.
Definition time.hpp:64
Vector class based on std::array storage and mdspan.
Definition matrix.hpp:96
handle basic utilities dealing with SPH
The shamrock SPH model.
Definition Solver.hpp:70
void reset_presteps_rint()
Resets tree radius interval field.
Definition Solver.cpp:1196
void reset_merge_ghosts_fields()
Resets merged ghost field data.
Definition Solver.cpp:1473
void update_sync_load_values()
Updates load balancing values and synchronizes patch ownership.
Definition Solver.cpp:1565
bool apply_corrector(Tscal dt, u64 Npart_all)
Definition Solver.cpp:1560
void merge_position_ghost()
Merges ghost particle positions from neighboring patches.
Definition Solver.cpp:843
void reset_eos_fields()
Frees memory allocated for EOS fields.
Definition Solver.cpp:1499
void prepare_corrector()
Saves old derivative fields for predictor-corrector integration.
Definition Solver.cpp:1505
void build_ghost_cache()
Builds ghost particle interface cache for inter-patch communication.
Definition Solver.cpp:821
void update_artificial_viscosity(Tscal dt)
Updates artificial viscosity coefficients for shock capturing.
Definition Solver.cpp:1482
TimestepLog evolve_once()
Performs one complete SPH timestep evolution.
Definition Solver.cpp:1623
void vtk_do_dump(std::string filename, bool add_patch_world_id)
Writes VTK dump file for visualization.
Definition Solver.cpp:598
void update_derivs(Tscal dt_hydro)
Updates time derivatives and applies external forces.
Definition Solver.cpp:1550
void build_merged_pos_trees()
Builds spatial BVH trees for merged positions including ghosts.
Definition Solver.cpp:886
void clear_merged_pos_trees()
Clears merged position trees to free memory.
Definition Solver.cpp:891
void init_solver_graph()
Initializes the solver graph for computation pipeline.
Definition Solver.cpp:112
void sph_prestep(Tscal time_val, Tscal dt)
Performs pre-step operations for SPH timestep.
Definition Solver.cpp:897
void compute_presteps_rint()
Computes maximum smoothing length in tree nodes for neighbor search.
Definition Solver.cpp:1159
void compute_eos_fields()
Computes equation of state fields (pressure, sound speed).
Definition Solver.cpp:1493
void apply_position_boundary(Tscal time_val)
Applies position-based boundary conditions.
Definition Solver.cpp:775
void reset_neighbors_cache()
Resets neighbor cache.
Definition Solver.cpp:1226
void communicate_merge_ghosts_fields()
Communicates and merges ghost particle fields across processes.
Definition Solver.cpp:1231
void clear_ghost_cache()
Clears ghost particle cache to free memory.
Definition Solver.cpp:837
void init_ghost_layout()
Initializes data layout for ghost particle fields.
Definition Solver.cpp:1144
void start_neighbors_cache()
Builds neighbor particle cache for SPH calculations.
Definition Solver.cpp:1201
Module for constructing spatial tree structures for SPH neighbor searches.
void build_merged_pos_trees()
Builds compressed leaf BVH trees for merged particle positions including ghosts.
Module for computing equation of state quantities.
void compute_eos()
Computes pressure and sound speed from equation of state.
Module for checking conservation of physical quantities.
void check_conservation()
Verifies conservation of mass, momentum, and energy.
void add_ext_forces()
add external forces to the particle acceleration, note that forces dependent on velocity should be a...
void compute_ext_forces_indep_v()
is run once per timestep, it computes the forces that are independent of velocity
Module for reordering particles to improve cache locality.
void reorder_particles()
Reorders particles by Morton code for improved memory access patterns.
Module for writing VTK format output files.
Definition VTKDump.hpp:33
void do_dump(std::string filename, bool add_patch_world_id)
Writes particle data to VTK file for visualization.
Definition VTKDump.cpp:37
Utility class used to move the objects between patches.
void reatribute_patch_objects(SerialPatchTree< T > &sptree, std::string position_field)
Reattribute objects based on a given position field.
ComputeField< T > make_compute_field(std::string new_name, u32 nvar)
create a compute field and init it to zeros
ComputeField< T > save_field(u32 field_idx, std::string new_name)
save a field in patchdata to a compute field
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
Interface for a solver graph edge representing a field as spans.
PatchDataField< T > & get_field(u64 id) const
Get the underlying PatchDataField at the given id.
void evaluate()
Evaluate the node.
Definition INode.hpp:109
A node that simply frees the allocation of the connected node.
A node that applies a custom function to modify connected edges.
void set_edges(std::shared_ptr< IEdge > to_set)
Set the edges of the node.
virtual void free_alloc() override
Free allocated memory.
A graph container for managing solver nodes and edges with type-safe access.
std::shared_ptr< INode > & get_node_ptr_base(const std::string &name)
Retrieve a node by name as a shared pointer to the base interface.
T & get_edge_ref(const std::string &name)
Get a typed reference to an edge by name.
std::shared_ptr< T > get_edge_ptr(const std::string &name)
Get a typed shared pointer to an edge by name.
std::shared_ptr< T > register_edge(const std::string &name, T &&edge)
Register an edge with automatic type deduction and shared pointer creation.
std::shared_ptr< T > register_node(const std::string &name, T &&node)
Register a node with automatic type deduction and shared pointer creation.
INode & get_node_ref_base(const std::string &name)
Get a reference to a node by name through the base interface.
A Compressed Leaf Bounding Volume Hierarchy (CLBVH) for neighborhood queries.
A data structure representing a Karras Radix Tree Field.
This header file contains utility functions related to exception handling in the code.
MPI string gather / allgather helpers (declarations; implementations in shamalgs/src/collective/gathe...
MemPerfInfos get_mem_perf_info()
Retrieve the memory performance information.
This file contains the declaration of the memory handling and its methods.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
std::vector< T > buf_to_vec(sycl::buffer< T > &buf, u32 len)
Convert a sycl::buffer to a std::vector.
Definition memory.cpp:34
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
ExcptTypes make_except_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Create an exception with a message and a location.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for math utility
Definition AABB.hpp:26
namespace for the sph model
namespace for the main framework
Definition __init__.py:1
void info(std::string module_name, Types... var2)
Prints a log message with multiple arguments.
Definition logs.hpp:133
void raw_ln(Types... var2)
Prints a log message with multiple arguments followed by a newline.
Definition logs.hpp:90
void info_ln(std::string module_name, Types... var2)
Prints a log message with multiple arguments followed by a newline.
Definition logs.hpp:133
void warn_ln(std::string module_name, Types... var2)
Prints a log message with multiple arguments followed by a newline.
Definition logs.hpp:133
void err_ln(std::string module_name, Types... var2)
Prints a log message with multiple arguments followed by a newline.
Definition logs.hpp:133
file containing formulas for sph forces
sph kernels
shambase::details::NamedBasicStackEntry NamedStackEntry
Alias for shambase::details::NamedBasicStackEntry.
shambase::details::BasicStackEntry StackEntry
Alias for shambase::details::BasicStackEntry.
f64 get_wtime()
Returns the current wall clock time in seconds.
Structure to store the performance information about memory allocation and deallocation.
f64 time_alloc_host
Time spent allocating memory on the host.
size_t max_allocated_byte_host
max bytes allocated on the host
f64 time_free_device
Time spent deallocating memory on the device.
size_t max_allocated_byte_device
max bytes allocated on the device
f64 time_alloc_device
Time spent allocating memory on the device.
f64 time_free_host
Time spent deallocating memory on the host.
A class that references multiple buffers or similar objects.
A class to represent a single block of data in a Phantom dump.
u64 get_ref_f32(std::string s)
Gets the index of a block of type f32 with the given name.
u64 get_ref_fort_real(std::string s)
Gets the index of a block of type fort_real with the given name.
i64 tot_count
The total number of values in the block.
std::vector< PhantomDumpBlockArray< fort_real > > blocks_fort_real
The blocks of values of type fort_real.
std::vector< PhantomDumpBlockArray< f32 > > blocks_f32
The blocks of values of type f32.
Class representing a Phantom dump file.
void override_magic_number()
Overrides the magic numbers used in the PhantomDump struct.
BCConfig< Tvec > BCConfig
Configuration of the boundary conditions.
Patch object that contains generic patch information.
Definition Patch.hpp:33
u64 id_patch
unique key that identifies the patch
Definition Patch.hpp:86
Functions related to the MPI communicator.
f64 get_timer(std::string timername)
get a timer value
Definition wrapper.cpp:44