Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
ConservativeCheck.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
19#include "shamcomm/logs.hpp"
23
// Conservation diagnostics (presumably `check_conservation()`, per the tooltip list at the end
// of this listing). Every rank accumulates local particle sums; these are combined with
// shamalgs::collective::allreduce_sum, and rank 0 appends the results to `cv_checks` and logs
// them through logger::info_ln("sph::Model", ...). Reported quantities:
//   sum v  : gpart_mass * sum("vxyz") + sum over sinks of mass * velocity
//            (total momentum, assuming all gas particles share gpart_mass)
//   sum a  : gpart_mass * sum("axyz") + sum over sinks of mass * (sph + ext acceleration)
//   sum e  : gpart_mass * sum(uint + 0.5 * v.v)
//   sum de : sum over particles of gpart_mass * (v.a + duint), plus a magnetic term when
//            the "B/rho" field is evolved
// NOTE(review): this is a Doxygen extraction, and the jumps in the leading numbers show that
// source lines are missing. Source line 25 is presumably the function signature, and source
// lines 122 / 136 / 138 are presumably the openings of the two `sham::kernel_call(` calls
// (the first MultiRef of the second call is also cut). Recover them from the real file
// before editing the code.
24template<class Tvec, template<class> class SPHKernel>
26
27 StackEntry stack_loc{};
28
29 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
30 sham::DeviceQueue &q = shambase::get_check_ref(dev_sched).get_queue();
31
// Particle mass applied uniformly to every particle sum below.
32 Tscal gpart_mass = solver_config.gpart_mass;
33
34 using namespace shamrock;
35 using namespace shamrock::patch;
36 using Sink = SinkParticle<Tvec>;
37
38 PatchDataLayerLayout &pdl = scheduler().pdl_old();
39
// NOTE(review): ixyz and iaxyz_ext are looked up but never used in this function.
40 const u32 ixyz = pdl.get_field_idx<Tvec>("xyz");
41 const u32 ivxyz = pdl.get_field_idx<Tvec>("vxyz");
42 const u32 iaxyz = pdl.get_field_idx<Tvec>("axyz");
43 const u32 iaxyz_ext = pdl.get_field_idx<Tvec>("axyz_ext");
44 const u32 iuint = pdl.get_field_idx<Tscal>("uint");
45 const u32 iduint = pdl.get_field_idx<Tscal>("duint");
46 const u32 ihpart = pdl.get_field_idx<Tscal>("hpart");
47
// Magnetic fields are optional. When absent, -1 converts to the largest u32 value as a
// sentinel. These three indices are only read inside the later `if (has_B_field)` branch.
48 bool has_B_field = solver_config.has_field_B_on_rho();
49 const u32 iB_on_rho = (has_B_field) ? pdl.get_field_idx<Tvec>("B/rho") : -1;
50 const u32 idB_on_rho = (has_B_field) ? pdl.get_field_idx<Tvec>("dB/rho") : -1;
51 const u32 idrho_dt = (has_B_field) ? pdl.get_field_idx<Tscal>("drho/dt") : -1;
52
// Report text: built on every rank, but only rank 0 appends to it and logs it.
53 std::string cv_checks = "conservation infos :\n";
54
56 // momentum check :
58 Tvec tmpp{0, 0, 0};
59 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
60 PatchDataField<Tvec> &field = pdat.get_field<Tvec>(ivxyz);
61 tmpp += field.compute_sum();
62 });
63 Tvec sum_p = gpart_mass * shamalgs::collective::allreduce_sum(tmpp);
64
// Sink contributions are added after the allreduce and only on rank 0 (the reporting rank).
// NOTE(review): the label below says "sum v", but the value is mass-weighted (momentum).
65 if (shamcomm::world_rank() == 0) {
66 if (!storage.sinks.is_empty()) {
67 std::vector<Sink> &sink_parts = storage.sinks.get();
68 for (Sink &s : sink_parts) {
69 sum_p += s.mass * s.velocity;
70 }
71 }
72 cv_checks += shambase::format(" sum v = {}\n", sum_p);
73 }
74
76 // force sum check :
// NOTE(review): particles contribute only "axyz", while sinks add sph + ext acceleration.
// Confirm whether "axyz" already includes external forces (iaxyz_ext is unused above).
78 Tvec tmpa{0, 0, 0};
79 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
80 PatchDataField<Tvec> &field = pdat.get_field<Tvec>(iaxyz);
81 tmpa += field.compute_sum();
82 });
83 Tvec sum_a = gpart_mass * shamalgs::collective::allreduce_sum(tmpa);
84
85 if (shamcomm::world_rank() == 0) {
86 if (!storage.sinks.is_empty()) {
87 std::vector<Sink> &sink_parts = storage.sinks.get();
88 for (Sink &s : sink_parts) {
89 sum_a += s.mass * (s.sph_acceleration + s.ext_acceleration);
90 }
91 }
92 cv_checks += shambase::format(" sum a = {}\n", sum_a);
93 }
94
96 // energy check :
// Specific internal energy plus 0.5 * v.v (compute_dot_sum presumably returns the sum of
// v.v), scaled by gpart_mass after the reduction.
// NOTE(review): unlike sum v / sum a, sink particles are not added to sum e or sum de — confirm intended.
98 Tscal tmpe{0};
99 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
100 PatchDataField<Tscal> &field_u = pdat.get_field<Tscal>(iuint);
101 PatchDataField<Tvec> &field_v = pdat.get_field<Tvec>(ivxyz);
102 tmpe += field_u.compute_sum() + 0.5 * field_v.compute_dot_sum();
103 });
104 Tscal sum_e = gpart_mass * shamalgs::collective::allreduce_sum(tmpe);
105
106 if (shamcomm::world_rank() == 0) {
107 cv_checks += shambase::format(" sum e = {}\n", sum_e);
108 }
109
// Energy rate: for each patch, fill a per-particle device buffer with the rate contribution,
// reduce it on device, then sum the patch results across ranks.
110 Tscal pmass = gpart_mass;
111 Tscal tmp_de = 0;
112 scheduler().for_each_patchdata_nonempty([&, pmass](Patch cur_p, PatchDataLayer &pdat) {
113 PatchDataField<Tvec> &field_v = pdat.get_field<Tvec>(ivxyz);
114 PatchDataField<Tscal> &field_du = pdat.get_field<Tscal>(iduint);
115 PatchDataField<Tvec> &field_a = pdat.get_field<Tvec>(iaxyz);
116 PatchDataField<Tscal> &field_hpart = pdat.get_field<Tscal>(ihpart);
117
// One entry per particle of this patch.
118 sham::DeviceBuffer<Tscal> temp_de(pdat.get_obj_cnt(), dev_sched);
119
120 Tscal const mu_0 = solver_config.get_constant_mu_0();
121
// First kernel (its opening call line is missing from this listing): overwrites
// de[i] = pmass * (v_i . a_i + du_i).
123 q,
124 sham::MultiRef{field_du.get_buf(), field_v.get_buf(), field_a.get_buf()},
125 sham::MultiRef{temp_de},
126 pdat.get_obj_cnt(),
127 [=](u32 item, const Tscal *du, const Tvec *v, const Tvec *a, Tscal *de) {
128 de[item] = pmass * (sycl::dot(v[item], a[item]) + du[item]);
129 });
130
131 if (has_B_field) {
132 PatchDataField<Tvec> &field_B_on_rho = pdat.get_field<Tvec>(iB_on_rho);
133 PatchDataField<Tvec> &field_dB_on_rho = pdat.get_field<Tvec>(idB_on_rho);
134 PatchDataField<Tscal> &field_drho_dt = pdat.get_field<Tscal>(idrho_dt);
135
// Second kernel (its opening lines are missing from this listing): adds a magnetic term
// to de[i] in place.
137 q,
139 field_hpart.get_buf(),
140 field_B_on_rho.get_buf(),
141 field_dB_on_rho.get_buf(),
142 field_drho_dt.get_buf()},
143 sham::MultiRef{temp_de},
144 pdat.get_obj_cnt(),
145 [=](u32 item,
146 const Tscal *hpart,
147 const Tvec *B_on_rho,
148 const Tvec *dB_on_rho,
149 const Tscal *drho_dt,
150 Tscal *de) {
151 using namespace shamrock::sph;
152 Tscal h = hpart[item];
153 Tscal term_B = 0.;
154
155 Tvec B_on_rho_a = B_on_rho[item];
// Recover B from B/rho. Assumes rho_h(pmass, h, hfactd) returns the SPH density from the
// smoothing length — confirm.
156 Tvec B = B_on_rho_a * shamrock::sph::rho_h(pmass, h, Kernel::hfactd);
157 Tvec dB_on_rho_a = dB_on_rho[item];
158 Tscal drho = drho_dt[item];
// term_B = |B/rho|^2 * drho/dt / (2 mu_0) + B . d(B/rho)/dt / mu_0.
// This is the product-rule time derivative of rho * |B/rho|^2 / (2 mu_0) = |B|^2 / (2 mu_0 rho),
// presumably the magnetic energy per unit mass.
// NOTE(review): 0.5 and 1. are double literals. If Tscal is float, this expression is
// evaluated in double inside the device kernel, which fails on devices without fp64
// support — consider Tscal-typed constants.
159 term_B = 0.5 * (1. / mu_0) * sycl::dot(B_on_rho_a, B_on_rho_a) * drho
160 + (1. / mu_0) * sycl::dot(B, dB_on_rho_a);
161
162 de[item] += pmass * term_B;
163 });
164 }
165
// Device-side reduction over [0, obj_cnt) of this patch's per-particle contributions.
166 Tscal de_p = shamalgs::primitives::sum(dev_sched, temp_de, 0, pdat.get_obj_cnt());
167 tmp_de += de_p;
168 });
169
// pmass was already applied per particle inside the kernels, so it is not multiplied again here.
170 Tscal de = shamalgs::collective::allreduce_sum(tmp_de);
171
// No trailing newline: this is the last line of the report.
172 if (shamcomm::world_rank() == 0) {
173 cv_checks += shambase::format(" sum de = {}", de);
174 }
175
176 if (shamcomm::world_rank() == 0) {
177 logger::info_ln("sph::Model", cv_checks);
178 }
179}
180
181using namespace shammath;
185
std::uint32_t u32
32 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
A SYCL queue associated with a device and a context.
Module for checking conservation of physical quantities.
void check_conservation()
Verifies conservation of momentum and energy: logs the global sums of momentum, forces, total energy, and the energy time derivative.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for math utility
Definition AABB.hpp:26
namespace for the main framework
Definition __init__.py:1
sph kernels
A class that references multiple buffers or similar objects.
Patch object that contain generic patch information.
Definition Patch.hpp:33