Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
DiffOperator.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
23
// Computes an SPH estimate of the velocity divergence for every particle of each
// non-empty patch and writes it to the "divv" field of that patch.
//
// For particle a, with neighbours b taken from the per-patch neighbour cache:
//   divv_a = -1 / (omega_a * rho_a) * sum_b pmass * dot(v_a - v_b, dW_3d(r_ab, h_a) * r_ab_unit)
// where rho_a = rho_h(pmass, h_a, Kernel::hfactd).
//
// NOTE(review): the function signature line (and a few declaration lines further
// down) are absent from this listing; the name and parameters must be confirmed
// against the real source file.
24template<class Tvec, template<class> class SPHKernel>
26
27 StackEntry stack_loc{};
28 shamlog_debug_ln("SPH", "Updating divv");
29
// Single particle mass shared by all particles, read from the solver configuration.
30 Tscal gpart_mass = solver_config.gpart_mass;
31
32 using namespace shamrock;
33 using namespace shamrock::patch;
34
35 PatchDataLayerLayout &pdl = scheduler().pdl_old();
36
37 shambase::DistributedData<PatchDataLayer> &mpdats = storage.merged_patchdata_ghost.get();
38
39 auto &merged_xyzh = storage.merged_xyzh.get();
40
// NOTE(review): the left-hand side of this initialisation is missing from the
// listing; presumably it declares `ghost_layout`, which is used just below.
42 = shambase::get_check_ref(storage.ghost_layout.get());
// Field indices inside the merged (local + ghost) patch data.
// NOTE: iuint_interf is looked up but the "uint" field is never read by the kernel.
43 u32 ihpart_interf = ghost_layout.get_field_idx<Tscal>("hpart");
44 u32 iuint_interf = ghost_layout.get_field_idx<Tscal>("uint");
45 u32 ivxyz_interf = ghost_layout.get_field_idx<Tvec>("vxyz");
46 u32 iomega_interf = ghost_layout.get_field_idx<Tscal>("omega");
47
// Output field index, looked up in the patch data layout `pdl` (not the ghost layout).
48 const u32 idivv = pdl.get_field_idx<Tscal>("divv");
49 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
50 PatchDataLayer &mpdat = mpdats.get(cur_p.id_patch);
51
// NOTE(review): the declaration head of this reference is missing from the listing;
// presumably it is `buf_xyz` (field 0 of merged_xyzh), which is read below.
53 = merged_xyzh.get(cur_p.id_patch).template get_field_buf_ref<Tvec>(0);
54 sham::DeviceBuffer<Tvec> &buf_vxyz = mpdat.get_field_buf_ref<Tvec>(ivxyz_interf);
55 sham::DeviceBuffer<Tscal> &buf_hpart = mpdat.get_field_buf_ref<Tscal>(ihpart_interf);
56 sham::DeviceBuffer<Tscal> &buf_omega = mpdat.get_field_buf_ref<Tscal>(iomega_interf);
// NOTE: buf_uint is bound here but never accessed afterwards.
57 sham::DeviceBuffer<Tscal> &buf_uint = mpdat.get_field_buf_ref<Tscal>(iuint_interf);
// The result is written to the patch's own data (pdat), not to the merged data.
58 sham::DeviceBuffer<Tscal> &buf_divv = pdat.get_field_buf_ref<Tscal>(idivv);
59
// NOTE: range_npart is not used below; the kernel is sized with pdat.get_obj_cnt() directly.
60 sycl::range range_npart{pdat.get_obj_cnt()};
61
// Neighbour cache of the current patch.
62 tree::ObjectCache &pcache
63 = shambase::get_check_ref(storage.neigh_cache).get_cache(cur_p.id_patch);
64
66
67 {
68 NamedStackEntry tmppp{"compute divv"};
69
// Events collected while acquiring the accessors; the kernel submission depends on them.
70 sham::EventList depends_list;
71
72 auto xyz = buf_xyz.get_read_access(depends_list);
73 auto vxyz = buf_vxyz.get_read_access(depends_list);
74 auto hpart = buf_hpart.get_read_access(depends_list);
75 auto omega = buf_omega.get_read_access(depends_list);
76 auto divv = buf_divv.get_write_access(depends_list);
77 auto ploop_ptrs = pcache.get_read_access(depends_list);
78
79 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
80
81 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
82 const Tscal pmass = gpart_mass;
83
84 tree::ObjectCacheIterator particle_looper(ploop_ptrs);
85
// Square of Kernel::Rkern, compared against squared separations scaled by h^2 below.
86 constexpr Tscal Rker2 = Kernel::Rkern * Kernel::Rkern;
87
// One work item per particle of the patch (pdat.get_obj_cnt() items).
88 shambase::parallel_for(cgh, pdat.get_obj_cnt(), "compute divv", [=](i32 id_a) {
89 using namespace shamrock::sph;
90
// NOTE: sum_axyz and sum_du_a are declared but never used in this kernel.
91 Tvec sum_axyz = {0, 0, 0};
92 Tscal sum_du_a = 0;
93 Tscal h_a = hpart[id_a];
94 Tvec xyz_a = xyz[id_a];
95 Tvec vxyz_a = vxyz[id_a];
96 Tscal omega_a = omega[id_a];
97
98 Tscal rho_a = rho_h(pmass, h_a, Kernel::hfactd);
99 // Tscal rho_a_sq = rho_a * rho_a;
100 // Tscal rho_a_inv = 1. / rho_a;
101 Tscal inv_rho_omega_a = 1. / (omega_a * rho_a);
102
// Accumulator for sum_b pmass * dot(v_a - v_b, grad_a W_ab(h_a)).
103 Tscal sum_nabla_v = 0;
104
105 particle_looper.for_each_object(id_a, [&](u32 id_b) {
106 // contribution of neighbour id_b (only the h_a kernel gradient is used)
107 Tvec dr = xyz_a - xyz[id_b];
108 Tscal rab2 = sycl::dot(dr, dr);
109 Tscal h_b = hpart[id_b];
110
// Skip the pair when the squared separation exceeds both (h_a*Rkern)^2 and (h_b*Rkern)^2.
111 if (rab2 > h_a * h_a * Rker2 && rab2 > h_b * h_b * Rker2) {
112 return;
113 }
114
115 Tscal rab = sycl::sqrt(rab2);
116 Tvec vxyz_b = vxyz[id_b];
117 Tvec v_ab = vxyz_a - vxyz_b;
118
119 Tvec r_ab_unit = dr / rab;
120
// For (nearly) coincident particles the division above is ill-defined:
// replace the direction by the zero vector.
121 if (rab < 1e-9) {
122 r_ab_unit = {0, 0, 0};
123 }
124
// Kernel gradient evaluated with the smoothing length of particle a.
125 Tvec dWab_a = Kernel::dW_3d(rab, h_a) * r_ab_unit;
126
127 sum_nabla_v += pmass * sycl::dot(v_ab, dWab_a);
128 });
129
130 divv[id_a] = -inv_rho_omega_a * sum_nabla_v;
131 });
132 });
133
// Register the kernel event on every buffer whose access was requested above.
134 buf_xyz.complete_event_state(e);
135 buf_vxyz.complete_event_state(e);
136 buf_hpart.complete_event_state(e);
137 buf_omega.complete_event_state(e);
138 buf_divv.complete_event_state(e);
139
// The neighbour cache takes an EventList rather than a single event.
140 sham::EventList resulting_events;
141 resulting_events.add_event(e);
142 pcache.complete_event_state(resulting_events);
143 }
144 });
145}
146
// Computes an SPH estimate of the velocity curl for every particle of each
// non-empty patch and writes it to the "curlv" field of that patch.
//
// For particle a, with neighbours b taken from the per-patch neighbour cache:
//   curlv_a = -1 / (omega_a * rho_a) * sum_b pmass * cross(v_a - v_b, dW_3d(r_ab, h_a) * r_ab_unit)
// where rho_a = rho_h(pmass, h_a, Kernel::hfactd).
//
// NOTE(review): the function signature line (and a few declaration lines further
// down) are absent from this listing; the name and parameters must be confirmed
// against the real source file.
147template<class Tvec, template<class> class SPHKernel>
149
150 StackEntry stack_loc{};
151 shamlog_debug_ln("SPH", "Updating curlv");
152
// Single particle mass shared by all particles, read from the solver configuration.
153 Tscal gpart_mass = solver_config.gpart_mass;
154
155 using namespace shamrock;
156 using namespace shamrock::patch;
157
158 PatchDataLayerLayout &pdl = scheduler().pdl_old();
159
160 shambase::DistributedData<PatchDataLayer> &mpdats = storage.merged_patchdata_ghost.get();
161
162 auto &merged_xyzh = storage.merged_xyzh.get();
163
// NOTE(review): the left-hand side of this initialisation is missing from the
// listing; presumably it declares `ghost_layout`, which is used just below.
165 = shambase::get_check_ref(storage.ghost_layout.get());
// Field indices inside the merged (local + ghost) patch data.
// NOTE: iuint_interf is looked up but the "uint" field is never read by the kernel.
166 u32 ihpart_interf = ghost_layout.get_field_idx<Tscal>("hpart");
167 u32 iuint_interf = ghost_layout.get_field_idx<Tscal>("uint");
168 u32 ivxyz_interf = ghost_layout.get_field_idx<Tvec>("vxyz");
169 u32 iomega_interf = ghost_layout.get_field_idx<Tscal>("omega");
170
// Output field index, looked up in the patch data layout `pdl` (not the ghost layout).
171 const u32 icurlv = pdl.get_field_idx<Tvec>("curlv");
172 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
173 PatchDataLayer &mpdat = mpdats.get(cur_p.id_patch);
174
// NOTE(review): the declaration head of this reference is missing from the listing;
// presumably it is `buf_xyz` (field 0 of merged_xyzh), which is read below.
176 = merged_xyzh.get(cur_p.id_patch).template get_field_buf_ref<Tvec>(0);
177 sham::DeviceBuffer<Tvec> &buf_vxyz = mpdat.get_field_buf_ref<Tvec>(ivxyz_interf);
178 sham::DeviceBuffer<Tscal> &buf_hpart = mpdat.get_field_buf_ref<Tscal>(ihpart_interf);
179 sham::DeviceBuffer<Tscal> &buf_omega = mpdat.get_field_buf_ref<Tscal>(iomega_interf);
// NOTE: buf_uint is bound here but never accessed afterwards.
180 sham::DeviceBuffer<Tscal> &buf_uint = mpdat.get_field_buf_ref<Tscal>(iuint_interf);
// The result is written to the patch's own data (pdat), not to the merged data.
181 sham::DeviceBuffer<Tvec> &buf_curlv = pdat.get_field_buf_ref<Tvec>(icurlv);
182
// NOTE: range_npart is not used below; the kernel is sized with pdat.get_obj_cnt() directly.
183 sycl::range range_npart{pdat.get_obj_cnt()};
184
// Neighbour cache of the current patch.
185 tree::ObjectCache &pcache
186 = shambase::get_check_ref(storage.neigh_cache).get_cache(cur_p.id_patch);
187
189
190 {
191 NamedStackEntry tmppp{"compute curlv"};
192
// Events collected while acquiring the accessors; the kernel submission depends on them.
193 sham::EventList depends_list;
194 auto xyz = buf_xyz.get_read_access(depends_list);
195 auto vxyz = buf_vxyz.get_read_access(depends_list);
196 auto hpart = buf_hpart.get_read_access(depends_list);
197 auto omega = buf_omega.get_read_access(depends_list);
198 auto curlv = buf_curlv.get_write_access(depends_list);
199 auto ploop_ptrs = pcache.get_read_access(depends_list);
200
201 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
202
203 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
204 const Tscal pmass = gpart_mass;
205
206 tree::ObjectCacheIterator particle_looper(ploop_ptrs);
207
// Square of Kernel::Rkern, compared against squared separations scaled by h^2 below.
208 constexpr Tscal Rker2 = Kernel::Rkern * Kernel::Rkern;
209
// One work item per particle of the patch (pdat.get_obj_cnt() items).
210 shambase::parallel_for(cgh, pdat.get_obj_cnt(), "compute curlv", [=](i32 id_a) {
211 using namespace shamrock::sph;
212
// NOTE: sum_axyz and sum_du_a are declared but never used in this kernel.
213 Tvec sum_axyz = {0, 0, 0};
214 Tscal sum_du_a = 0;
215 Tscal h_a = hpart[id_a];
216 Tvec xyz_a = xyz[id_a];
217 Tvec vxyz_a = vxyz[id_a];
218 Tscal omega_a = omega[id_a];
219
220 Tscal rho_a = rho_h(pmass, h_a, Kernel::hfactd);
221 // Tscal rho_a_sq = rho_a * rho_a;
222 // Tscal rho_a_inv = 1. / rho_a;
223 Tscal inv_rho_omega_a = 1. / (omega_a * rho_a);
224
// Accumulator for sum_b pmass * cross(v_a - v_b, grad_a W_ab(h_a)); value-initialised.
225 Tvec sum_nabla_cross_v{};
226
227 particle_looper.for_each_object(id_a, [&](u32 id_b) {
228 // contribution of neighbour id_b (only the h_a kernel gradient is used)
229 Tvec dr = xyz_a - xyz[id_b];
230 Tscal rab2 = sycl::dot(dr, dr);
231 Tscal h_b = hpart[id_b];
232
// Skip the pair when the squared separation exceeds both (h_a*Rkern)^2 and (h_b*Rkern)^2.
233 if (rab2 > h_a * h_a * Rker2 && rab2 > h_b * h_b * Rker2) {
234 return;
235 }
236
237 Tscal rab = sycl::sqrt(rab2);
238 Tvec vxyz_b = vxyz[id_b];
239 Tvec v_ab = vxyz_a - vxyz_b;
240
241 Tvec r_ab_unit = dr / rab;
242
// For (nearly) coincident particles the division above is ill-defined:
// replace the direction by the zero vector.
243 if (rab < 1e-9) {
244 r_ab_unit = {0, 0, 0};
245 }
246
// Kernel gradient evaluated with the smoothing length of particle a.
247 Tvec dWab_a = Kernel::dW_3d(rab, h_a) * r_ab_unit;
248
249 sum_nabla_cross_v += pmass * sycl::cross(v_ab, dWab_a);
250 });
251
252 curlv[id_a] = -inv_rho_omega_a * sum_nabla_cross_v;
253 });
254 });
255
// Register the kernel event on every buffer whose access was requested above.
256 buf_xyz.complete_event_state(e);
257 buf_vxyz.complete_event_state(e);
258 buf_hpart.complete_event_state(e);
259 buf_omega.complete_event_state(e);
260 buf_curlv.complete_event_state(e);
261
// The neighbour cache takes an EventList rather than a single event.
262 sham::EventList resulting_events;
263 resulting_events.add_event(e);
264 pcache.complete_event_state(resulting_events);
265 }
266 });
267}
268
269using namespace shammath;
273
constexpr const char * vxyz
3-velocity field
constexpr const char * xyz
Position field (3D coordinates)
constexpr const char * hpart
Smoothing length field.
constexpr const char * omega
Grad-h correction factor \Omega.
std::uint32_t u32
32 bit unsigned integer
std::int32_t i32
32 bit integer
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
void add_event(sycl::event e)
Add an event to the list of events.
Definition EventList.hpp:87
Represents a collection of objects distributed across patches identified by a u64 id.
T & get(u64 id)
Returns a reference to an object in the collection.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the std::unique_ptr does not hold anything.
Definition memory.hpp:110
namespace for math utility
Definition AABB.hpp:26
namespace for the main framework
Definition __init__.py:1
sph kernels
This file contains the definition for the stacktrace related functionality.
Patch object that contain generic patch information.
Definition Patch.hpp:33
u64 id_patch
unique key that identify the patch
Definition Patch.hpp:86