Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
DragIntegrator.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
19#include "shambase/string.hpp"
27#include <stdexcept>
28
// -----------------------------------------------------------------------------
// Predictor step: advance the gas and dust conserved fields by one explicit
// Euler step, y_next = y + dt * (dy/dt), WITHOUT any drag coupling.
// The results are written to storage.*_next_no_drag, which the drag
// integrators below consume as their starting state.
// NOTE(review): this listing is a Doxygen extraction; the function signature
// (original line 30) and some declarations (lines 53, 56, 58, 73-74) are
// missing from this view — confirm against the repository before editing.
// -----------------------------------------------------------------------------
29template<class Tvec, class TgridVec>
31 Tscal dt) {
32
33 StackEntry stack_lock{};
34
35 using namespace shamrock::patch;
36 using namespace shamrock;
37 using namespace shammath;
38
39 const u32 ndust = solver_config.dust_config.ndust;
40
// Scratch compute fields for the post-step (pre-drag) state. Dust fields are
// flattened with ndust entries per cell: index = cell * ndust + dust_index.
41 SchedulerUtility utility(scheduler());
42 shamrock::ComputeField<Tscal> cfield_rho_next_bf_drag
43 = utility.make_compute_field<Tscal>("rho_next_bf_drag", AMRBlock::block_size);
44 shamrock::ComputeField<Tvec> cfield_rhov_next_bf_drag
45 = utility.make_compute_field<Tvec>("rhov_next_bf_drag", AMRBlock::block_size);
46 shamrock::ComputeField<Tscal> cfield_rhoe_next_bf_drag
47 = utility.make_compute_field<Tscal>("rhoe_next_bf_drag", AMRBlock::block_size);
48 shamrock::ComputeField<Tscal> cfield_rho_d_next_bf_drag
49 = utility.make_compute_field<Tscal>("rho_d_next_bf_drag", ndust * AMRBlock::block_size);
50 shamrock::ComputeField<Tvec> cfield_rhov_d_next_bf_drag
51 = utility.make_compute_field<Tvec>("rhov_d_next_bf_drag", ndust * AMRBlock::block_size);
52
// Time-derivative fields produced earlier in the solver graph.
// NOTE(review): the declarations for cfield_dtrho (line 53), cfield_dtrho_d
// (line 56) and cfield_dtrhov_d (line 58) are missing from this extraction.
54 shamrock::solvergraph::Field<Tvec> &cfield_dtrhov = shambase::get_check_ref(storage.dtrhov);
55 shamrock::solvergraph::Field<Tscal> &cfield_dtrhoe = shambase::get_check_ref(storage.dtrhoe);
57 = shambase::get_check_ref(storage.dtrho_dust);
59 = shambase::get_check_ref(storage.dtrhov_dust);
60
61 // load layout info
62 PatchDataLayerLayout &pdl = scheduler().pdl_old();
63
64 const u32 icell_min = pdl.get_field_idx<TgridVec>("cell_min");
65 const u32 icell_max = pdl.get_field_idx<TgridVec>("cell_max");
66 const u32 irho = pdl.get_field_idx<Tscal>("rho");
67 const u32 irhoetot = pdl.get_field_idx<Tscal>("rhoetot");
68 const u32 irhovel = pdl.get_field_idx<Tvec>("rhovel");
69 const u32 irho_d = pdl.get_field_idx<Tscal>("rho_dust");
70 const u32 irhovel_d = pdl.get_field_idx<Tvec>("rhovel_dust");
71
72 scheduler().for_each_patchdata_nonempty([&, dt, ndust](
75 shamlog_debug_ln(
76 "[AMR evolve time step before drag ]", "evolve field with no drag patch", p.id_patch);
77
78 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
79 u32 id = p.id_patch;
80
81 sham::DeviceBuffer<Tscal> &dt_rho_patch = cfield_dtrho.get_buf(id);
82 sham::DeviceBuffer<Tvec> &dt_rhov_patch = cfield_dtrhov.get_buf(id);
83 sham::DeviceBuffer<Tscal> &dt_rhoe_patch = cfield_dtrhoe.get_buf(id);
84 sham::DeviceBuffer<Tscal> &dt_rho_d_patch = cfield_dtrho_d.get_buf(id);
85 sham::DeviceBuffer<Tvec> &dt_rhov_d_patch = cfield_dtrhov_d.get_buf(id);
86
87 sham::DeviceBuffer<Tscal> &buf_rho = pdat.get_field_buf_ref<Tscal>(irho);
88 sham::DeviceBuffer<Tvec> &buf_rhov = pdat.get_field_buf_ref<Tvec>(irhovel);
89 sham::DeviceBuffer<Tscal> &buf_rhoe = pdat.get_field_buf_ref<Tscal>(irhoetot);
90 sham::DeviceBuffer<Tscal> &buf_rho_d = pdat.get_field_buf_ref<Tscal>(irho_d);
91 sham::DeviceBuffer<Tvec> &buf_rhov_d = pdat.get_field_buf_ref<Tvec>(irhovel_d);
92
93 sham::DeviceBuffer<Tscal> &rho_patch = cfield_rho_next_bf_drag.get_buf_check(id);
94 sham::DeviceBuffer<Tvec> &rhov_patch = cfield_rhov_next_bf_drag.get_buf_check(id);
95 sham::DeviceBuffer<Tscal> &rhoe_patch = cfield_rhoe_next_bf_drag.get_buf_check(id);
96 sham::DeviceBuffer<Tscal> &rho_d_patch = cfield_rho_d_next_bf_drag.get_buf_check(id);
97 sham::DeviceBuffer<Tvec> &rhov_d_patch = cfield_rhov_d_next_bf_drag.get_buf_check(id);
98
// One work-item per AMR cell (block count * cells per block).
99 u32 cell_count = pdat.get_obj_cnt() * AMRBlock::block_size;
100
// First kernel: explicit Euler update of the gas fields (rho, rhov, rhoe).
101 sham::EventList depend_list;
102 auto acc_dt_rho_patch = dt_rho_patch.get_read_access(depend_list);
103 auto acc_dt_rhov_patch = dt_rhov_patch.get_read_access(depend_list);
104 auto acc_dt_rhoe_patch = dt_rhoe_patch.get_read_access(depend_list);
105
106 auto rho = buf_rho.get_read_access(depend_list);
107 auto rhov = buf_rhov.get_read_access(depend_list);
108 auto rhoe = buf_rhoe.get_read_access(depend_list);
109
110 auto acc_rho = rho_patch.get_write_access(depend_list);
111 auto acc_rhov = rhov_patch.get_write_access(depend_list);
112 auto acc_rhoe = rhoe_patch.get_write_access(depend_list);
113
114 auto e1 = q.submit(depend_list, [&, dt](sycl::handler &cgh) {
115 shambase::parallel_for(cgh, cell_count, "evolve field with no drag", [=](u32 id_a) {
116 acc_rho[id_a] = rho[id_a] + dt * acc_dt_rho_patch[id_a];
117 acc_rhov[id_a] = rhov[id_a] + dt * acc_dt_rhov_patch[id_a];
118 acc_rhoe[id_a] = rhoe[id_a] + dt * acc_dt_rhoe_patch[id_a];
119 });
120 });
121
122 dt_rho_patch.complete_event_state(e1);
123 dt_rhov_patch.complete_event_state(e1);
124 dt_rhoe_patch.complete_event_state(e1);
125
126 buf_rho.complete_event_state(e1);
127 buf_rhov.complete_event_state(e1);
128 buf_rhoe.complete_event_state(e1);
129
130 rho_patch.complete_event_state(e1);
131 rhov_patch.complete_event_state(e1);
132 rhoe_patch.complete_event_state(e1);
133
// Second kernel: same explicit update for the dust fields; the flat index
// range covers ndust * cell_count entries.
134 sham::EventList depend_list1;
135 auto acc_dt_rho_d_patch = dt_rho_d_patch.get_read_access(depend_list1);
136 auto acc_dt_rhov_d_patch = dt_rhov_d_patch.get_read_access(depend_list1);
137
138 auto rho_d = buf_rho_d.get_read_access(depend_list1);
139 auto rhov_d = buf_rhov_d.get_read_access(depend_list1);
140
141 auto acc_rho_d = rho_d_patch.get_write_access(depend_list1);
142 auto acc_rhov_d = rhov_d_patch.get_write_access(depend_list1);
143
144 auto e2 = q.submit(depend_list1, [&, dt, ndust](sycl::handler &cgh) {
145 shambase::parallel_for(
146 cgh, ndust * cell_count, "dust evolve field no drag", [=](u32 id_a) {
147 acc_rho_d[id_a] = rho_d[id_a] + dt * acc_dt_rho_d_patch[id_a];
148 acc_rhov_d[id_a] = rhov_d[id_a] + dt * acc_dt_rhov_d_patch[id_a];
149 });
150 });
151
152 dt_rho_d_patch.complete_event_state(e2);
153 dt_rhov_d_patch.complete_event_state(e2);
154
155 buf_rho_d.complete_event_state(e2);
156 buf_rhov_d.complete_event_state(e2);
157
158 rho_d_patch.complete_event_state(e2);
159 rhov_d_patch.complete_event_state(e2);
160 });
161
// Publish the pre-drag predicted state for the drag integrators.
162 storage.rho_next_no_drag.set(std::move(cfield_rho_next_bf_drag));
163 storage.rhov_next_no_drag.set(std::move(cfield_rhov_next_bf_drag));
164 storage.rhoe_next_no_drag.set(std::move(cfield_rhoe_next_bf_drag));
165 storage.rho_d_next_no_drag.set(std::move(cfield_rho_d_next_bf_drag));
166 storage.rhov_d_next_no_drag.set(std::move(cfield_rhov_d_next_bf_drag));
167}
168
// -----------------------------------------------------------------------------
// IRK1 (implicit first-order) drag update.
// Starting from the pre-drag predicted state (storage.*_next_no_drag), this
// couples gas and dust momenta implicitly: each dust momentum is relaxed
// toward the common velocity with factor 1/(1 + alpha_i * dt), and the gas
// momentum absorbs the complementary contribution. Results are written back
// into the patch data ("old") buffers.
// NOTE(review): Doxygen extraction — the function signature (line 170) and
// the patch-loop lambda parameter list (lines 203-204) are missing from
// this view; confirm against the repository before editing.
// -----------------------------------------------------------------------------
169template<class Tvec, class TgridVec>
171 Tscal dt) {
172 StackEntry stack_lock{};
173
174 using namespace shamrock::patch;
175 using namespace shamrock;
176 using namespace shammath;
177
178 shamrock::ComputeField<Tscal> &cfield_rho_new = storage.rho_next_no_drag.get();
179 shamrock::ComputeField<Tvec> &cfield_rhov_new = storage.rhov_next_no_drag.get();
180 shamrock::ComputeField<Tscal> &cfield_rhoe_new = storage.rhoe_next_no_drag.get();
181 shamrock::ComputeField<Tscal> &cfield_rho_d_new = storage.rho_d_next_no_drag.get();
182 shamrock::ComputeField<Tvec> &cfield_rhov_d_new = storage.rhov_d_next_no_drag.get();
183
184 // load layout info
185 PatchDataLayerLayout &pdl = scheduler().pdl_old();
186
187 const u32 icell_min = pdl.get_field_idx<TgridVec>("cell_min");
188 const u32 icell_max = pdl.get_field_idx<TgridVec>("cell_max");
189 const u32 irho = pdl.get_field_idx<Tscal>("rho");
190 const u32 irhoetot = pdl.get_field_idx<Tscal>("rhoetot");
191 const u32 irhovel = pdl.get_field_idx<Tvec>("rhovel");
192 const u32 irho_d = pdl.get_field_idx<Tscal>("rho_dust");
193 const u32 irhovel_d = pdl.get_field_idx<Tvec>("rhovel_dust");
194
195 const u32 ndust = solver_config.dust_config.ndust;
196 // alphas are dust collision rates
197 auto alphas_vector = solver_config.drag_config.alphas;
// NOTE(review): this host-side vector is never used — the kernels below
// recompute per-dust `inv_dt_alphas` locally. Candidate for removal.
198 std::vector<Tscal> inv_dt_alphas(ndust);
199 bool enable_frictional_heating = solver_config.drag_config.enable_frictional_heating;
// friction_control == 1 when frictional heating is DISABLED: the energy
// update then subtracts the dust dissipation instead of adding the gas
// drag work (see the energy update at the end of the kernel).
200 u32 friction_control = (enable_frictional_heating == false) ? 1 : 0;
201
202 scheduler().for_each_patchdata_nonempty([&, dt, ndust, friction_control](
205 shamlog_debug_ln("[AMR enable drag ]", "irk1 drag patch", p.id_patch);
206
207 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
208 u32 id = p.id_patch;
209 u32 cell_count = pdat.get_obj_cnt() * AMRBlock::block_size;
210
// Pre-drag predicted ("new") state, read-only inputs.
211 sham::DeviceBuffer<Tscal> &rho_new_patch = cfield_rho_new.get_buf_check(id);
212 sham::DeviceBuffer<Tvec> &rhov_new_patch = cfield_rhov_new.get_buf_check(id);
213 sham::DeviceBuffer<Tscal> &rhoe_new_patch = cfield_rhoe_new.get_buf_check(id);
214 sham::DeviceBuffer<Tscal> &rho_d_new_patch = cfield_rho_d_new.get_buf_check(id);
215 sham::DeviceBuffer<Tvec> &rhov_d_new_patch = cfield_rhov_d_new.get_buf_check(id);
216
// Patch-data buffers that receive the drag-corrected state.
217 sham::DeviceBuffer<Tscal> &rho_old = pdat.get_field_buf_ref<Tscal>(irho);
218 sham::DeviceBuffer<Tvec> &rhov_old = pdat.get_field_buf_ref<Tvec>(irhovel);
219 sham::DeviceBuffer<Tscal> &rhoe_old = pdat.get_field_buf_ref<Tscal>(irhoetot);
220 sham::DeviceBuffer<Tscal> &rho_d_old = pdat.get_field_buf_ref<Tscal>(irho_d);
221 sham::DeviceBuffer<Tvec> &rhov_d_old = pdat.get_field_buf_ref<Tvec>(irhovel_d);
222
// Upload the per-dust collision rates to the device.
223 sham::DeviceBuffer<Tscal> alphas_buf(ndust, shamsys::instance::get_compute_scheduler_ptr());
224
225 alphas_buf.copy_from_stdvec(alphas_vector);
226
227 sham::EventList depend_list;
228 auto acc_rho_new_patch = rho_new_patch.get_read_access(depend_list);
229 auto acc_rhov_new_patch = rhov_new_patch.get_read_access(depend_list);
230 auto acc_rhoe_new_patch = rhoe_new_patch.get_read_access(depend_list);
231 auto acc_rho_d_new_patch = rho_d_new_patch.get_read_access(depend_list);
232 auto acc_rhov_d_new_patch = rhov_d_new_patch.get_read_access(depend_list);
233
234 auto acc_rho_old = rho_old.get_write_access(depend_list);
235 auto acc_rhov_old = rhov_old.get_write_access(depend_list);
236 auto acc_rhoe_old = rhoe_old.get_write_access(depend_list);
237 auto acc_rho_d_old = rho_d_old.get_write_access(depend_list);
238 auto acc_rhov_d_old = rhov_d_old.get_write_access(depend_list);
239
240 auto acc_alphas = alphas_buf.get_read_access(depend_list);
241
242 auto e = q.submit(depend_list, [&, dt, ndust, friction_control](sycl::handler &cgh) {
243 shambase::parallel_for(cgh, cell_count, "add_drag [irk1]", [=](u32 id_a) {
// Accumulate the implicit gas momentum and the effective inertia:
// each dust species contributes with weight dt*alpha/(1 + dt*alpha).
244 Tvec tmp_mom_1 = acc_rhov_new_patch[id_a];
245 Tscal tmp_rho = acc_rho_old[id_a];
246
247 for (u32 i = 0; i < ndust; i++) {
248 const Tscal inv_dt_alphas = 1.0 / (1.0 + acc_alphas[i] * dt);
249 const Tscal dt_alphas = dt * acc_alphas[i];
250
251 tmp_mom_1
252 = tmp_mom_1
253 + dt_alphas * inv_dt_alphas * acc_rhov_d_new_patch[id_a * ndust + i];
254 tmp_rho = tmp_rho + dt_alphas * inv_dt_alphas * acc_rho_d_old[id_a * ndust + i];
255 }
256
// Common (implicit) velocity shared by gas and dust after the drag step.
257 Tscal tmp_inv_rho = 1.0 / tmp_rho;
258 Tvec tmp_vel = tmp_inv_rho * tmp_mom_1;
259 Tscal Eg = 0.0;
260
// Work done on the gas by drag: 0.5 * d(momentum) . (v_before + v_after).
261 Tscal inv_rho_g = 1.0 / acc_rho_new_patch[id_a];
262 Tvec vg_bf = inv_rho_g * acc_rhov_new_patch[id_a];
263 Tvec vg_af = inv_rho_g * acc_rho_old[id_a] * tmp_vel;
// NOTE(review): stray empty statement below (harmless).
264 ;
265 Tscal work_drag
266 = 0.5
267 * ((acc_rho_old[id_a] * tmp_vel[0] - acc_rhov_new_patch[id_a][0])
268 * (vg_bf[0] + vg_af[0])
269 + (acc_rho_old[id_a] * tmp_vel[1] - acc_rhov_new_patch[id_a][1])
270 * (vg_bf[1] + vg_af[1])
271 + (acc_rho_old[id_a] * tmp_vel[2] - acc_rhov_new_patch[id_a][2])
272 * (vg_bf[2] + vg_af[2]));
// Dissipation produced by each dust species relaxing toward tmp_vel.
273 Tscal dissipation = 0.0;
274 for (u32 i = 0; i < ndust; i++) {
275 const Tscal inv_dt_alphas = 1.0 / (1.0 + acc_alphas[i] * dt);
276 const Tscal dt_alphas = dt * acc_alphas[i];
277 Tscal inv_rho_d = 1.0 / acc_rho_d_new_patch[id_a * ndust + i];
278 Tvec vd_bf = inv_rho_d * acc_rhov_d_new_patch[id_a * ndust + i];
279 Tvec vd_af = inv_rho_d * inv_dt_alphas
280 * (acc_rhov_d_new_patch[id_a * ndust + i]
281 + dt_alphas * acc_rho_d_old[id_a * ndust + i] * tmp_vel);
282 dissipation += 0.5 * dt_alphas * inv_dt_alphas
283 * ((acc_rho_d_old[id_a * ndust + i] * tmp_vel[0]
284 - acc_rhov_d_new_patch[id_a * ndust + i][0])
285 * (vd_af[0] + vd_bf[0])
286 + (acc_rho_d_old[id_a * ndust + i] * tmp_vel[1]
287 - acc_rhov_d_new_patch[id_a * ndust + i][1])
288 * (vd_af[1] + vd_bf[1])
289 + (acc_rho_d_old[id_a * ndust + i] * tmp_vel[2]
290 - acc_rhov_d_new_patch[id_a * ndust + i][2])
291 * (vd_af[2] + vd_bf[2]));
292 }
293
// Energy: add drag work when frictional heating is on (friction_control=0),
// otherwise remove the dust dissipation (friction_control=1).
294 Eg += acc_rhoe_new_patch[id_a] + (1 - friction_control) * work_drag
295 - friction_control * dissipation;
// Write back the drag-corrected state into the patch buffers.
296 acc_rhov_old[id_a] = tmp_vel * acc_rho_old[id_a];
297 acc_rhoe_old[id_a] = Eg;
298 acc_rho_old[id_a] = acc_rho_new_patch[id_a];
299 for (u32 i = 0; i < ndust; i++) {
300 const Tscal inv_dt_alphas = 1.0 / (1.0 + acc_alphas[i] * dt);
301 const Tscal dt_alphas = dt * acc_alphas[i];
302 acc_rhov_d_old[id_a * ndust + i]
303 = inv_dt_alphas
304 * (acc_rhov_d_new_patch[id_a * ndust + i]
305 + dt_alphas * acc_rho_d_old[id_a * ndust + i] * tmp_vel);
306 acc_rho_d_old[id_a * ndust + i] = acc_rho_d_new_patch[id_a * ndust + i];
307 }
308 });
309 });
310
311 rho_new_patch.complete_event_state(e);
312 rhov_new_patch.complete_event_state(e);
313 rhoe_new_patch.complete_event_state(e);
314 rho_d_new_patch.complete_event_state(e);
315 rhov_d_new_patch.complete_event_state(e);
316
317 rho_old.complete_event_state(e);
318 rhov_old.complete_event_state(e);
319 rhoe_old.complete_event_state(e);
320 rho_d_old.complete_event_state(e);
321 rhov_d_old.complete_event_state(e);
322
323 alphas_buf.complete_event_state(e);
324 });
325}
326
// -----------------------------------------------------------------------------
// Exponential (matrix-exponential) drag integrator.
// For each cell it builds the (ndust+1)x(ndust+1) drag Jacobian, shifts it by
// -mu*Id, computes exp(dt*A) via shammath::mat_exp, rescales by exp(mu), and
// applies the resulting matrix to the gas+dust momentum vector. Two code
// paths are used depending on how many work-items' worth of scratch matrices
// fit in device local memory:
//   - group_size < 8: per-cell scratch in a global-memory buffer;
//   - otherwise: per-work-item scratch in sycl::local_accessor local memory.
// NOTE(review): Doxygen extraction — the function signature (line 328), the
// patch-loop lambda parameters (lines 362-363) and the head of the
// shambase::throw_with_loc call (lines 576, 583) are missing from this view;
// confirm against the repository before editing.
// -----------------------------------------------------------------------------
327template<class Tvec, class TgridVec>
329 Tscal dt) {
330 StackEntry stack_lock{};
331
332 using namespace shamrock::patch;
333 using namespace shamrock;
334 using namespace shammath;
335
336 shamrock::ComputeField<Tscal> &cfield_rho_new = storage.rho_next_no_drag.get();
337 shamrock::ComputeField<Tvec> &cfield_rhov_new = storage.rhov_next_no_drag.get();
338 shamrock::ComputeField<Tscal> &cfield_rhoe_new = storage.rhoe_next_no_drag.get();
339 shamrock::ComputeField<Tscal> &cfield_rho_d_new = storage.rho_d_next_no_drag.get();
340 shamrock::ComputeField<Tvec> &cfield_rhov_d_new = storage.rhov_d_next_no_drag.get();
341
342 // load layout info
343 PatchDataLayerLayout &pdl = scheduler().pdl_old();
344
345 const u32 icell_min = pdl.get_field_idx<TgridVec>("cell_min");
346 const u32 icell_max = pdl.get_field_idx<TgridVec>("cell_max");
347 const u32 irho = pdl.get_field_idx<Tscal>("rho");
348 const u32 irhoetot = pdl.get_field_idx<Tscal>("rhoetot");
349 const u32 irhovel = pdl.get_field_idx<Tvec>("rhovel");
350 const u32 irho_d = pdl.get_field_idx<Tscal>("rho_dust");
351 const u32 irhovel_d = pdl.get_field_idx<Tvec>("rhovel_dust");
352
353 const u32 ndust = solver_config.dust_config.ndust;
354
355 // alphas are dust collision rates
356 auto alphas_vector = solver_config.drag_config.alphas;
// NOTE(review): this host-side vector is never used — the kernels below
// recompute per-dust factors locally. Candidate for removal.
357 std::vector<Tscal> inv_dt_alphas(ndust);
358 bool enable_frictional_heating = solver_config.drag_config.enable_frictional_heating;
// friction_control == 1 when frictional heating is DISABLED: the energy
// update then subtracts the dissipation instead of adding the drag work.
359 u32 friction_control = (enable_frictional_heating == false) ? 1 : 0;
360
361 scheduler().for_each_patchdata_nonempty([&, dt, ndust, friction_control](
364 shamlog_debug_ln("[Ramses]", "expo drag on patch", p.id_patch);
365
366 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
367 u32 id = p.id_patch;
368 u32 cell_count = pdat.get_obj_cnt() * AMRBlock::block_size;
369
// Pre-drag predicted ("new") state, read-only inputs.
370 sham::DeviceBuffer<Tscal> &rho_new_patch = cfield_rho_new.get_buf_check(id);
371 sham::DeviceBuffer<Tvec> &rhov_new_patch = cfield_rhov_new.get_buf_check(id);
372 sham::DeviceBuffer<Tscal> &rhoe_new_patch = cfield_rhoe_new.get_buf_check(id);
373 sham::DeviceBuffer<Tscal> &rho_d_new_patch = cfield_rho_d_new.get_buf_check(id);
374 sham::DeviceBuffer<Tvec> &rhov_d_new_patch = cfield_rhov_d_new.get_buf_check(id);
375
// Patch-data buffers that receive the drag-corrected state.
376 sham::DeviceBuffer<Tscal> &rho_old = pdat.get_field_buf_ref<Tscal>(irho);
377 sham::DeviceBuffer<Tvec> &rhov_old = pdat.get_field_buf_ref<Tvec>(irhovel);
378 sham::DeviceBuffer<Tscal> &rhoe_old = pdat.get_field_buf_ref<Tscal>(irhoetot);
379 sham::DeviceBuffer<Tscal> &rho_d_old = pdat.get_field_buf_ref<Tscal>(irho_d);
380 sham::DeviceBuffer<Tvec> &rhov_d_old = pdat.get_field_buf_ref<Tvec>(irhovel_d);
381
// Upload the per-dust collision rates to the device.
382 sham::DeviceBuffer<Tscal> alphas_buf(ndust, shamsys::instance::get_compute_scheduler_ptr());
383
384 alphas_buf.copy_from_stdvec(alphas_vector);
385
386 sham::EventList depend_list;
387 auto acc_rho_new_patch = rho_new_patch.get_read_access(depend_list);
388 auto acc_rhov_new_patch = rhov_new_patch.get_read_access(depend_list);
389 auto acc_rhoe_new_patch = rhoe_new_patch.get_read_access(depend_list);
390 auto acc_rho_d_new_patch = rho_d_new_patch.get_read_access(depend_list);
391 auto acc_rhov_d_new_patch = rhov_d_new_patch.get_read_access(depend_list);
392
393 auto acc_rho_old = rho_old.get_write_access(depend_list);
394 auto acc_rhov_old = rhov_old.get_write_access(depend_list);
395 auto acc_rhoe_old = rhoe_old.get_write_access(depend_list);
396 auto acc_rho_d_old = rho_d_old.get_write_access(depend_list);
397 auto acc_rhov_d_old = rhov_d_old.get_write_access(depend_list);
398
399 auto acc_alphas = alphas_buf.get_read_access(depend_list);
400
// Each work-item needs 5 scratch matrices of (ndust+1)^2 scalars
// (A, B, F, I, Id). Size the work-group so that all of them fit in the
// device's local memory.
401 size_t mat_size = ndust + 1;
402 size_t mat_size_squared = mat_size * mat_size;
403 size_t group_size
404 = (q.get_device_prop().local_mem_size) / (5 * mat_size_squared * sizeof(Tscal));
405 size_t loc_acc_size = mat_size_squared * group_size;
406
407 size_t loc_mem_size = 5 * sizeof(Tscal) * loc_acc_size;
408
// Local memory too small for a reasonable group: fall back to a
// global-memory scratch buffer with 5 matrices per cell.
409 if (group_size < 8) {
410 sham::DeviceBuffer<Tscal> scratch_expo(
411 5 * mat_size_squared * cell_count, shamsys::instance::get_compute_scheduler_ptr());
412 Tscal *exp_scratch_ptr_base = scratch_expo.get_write_access(depend_list);
413 auto e = q.submit(depend_list, [&, dt, ndust, friction_control](sycl::handler &cgh) {
414 shambase::parallel_for(
415 cgh, cell_count, "add_drag [expo-global-mem]", [=](u32 id_a) {
416 // sparse jacobian matrix
417 auto get_jacobian =
418 [=](u32 id,
419 std::mdspan<
420 Tscal,
421 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
422 &jacobian) {
423 mat_set_nul<Tscal>(jacobian);
424 // fill first row
425 for (auto j = 1; j < jacobian.extent(1); j++)
426 jacobian(0, j) = acc_alphas[j - 1];
427 // fill first column
428 for (auto i = 1; i < jacobian.extent(0); i++) {
429 jacobian(i, 0) = acc_alphas[i - 1]
430 * (acc_rho_d_new_patch[id * ndust + (i - 1)]
431 / acc_rho_new_patch[id]);
432 jacobian(0, 0) -= jacobian(i, 0);
433 }
434 // fill diagonal from (i,j)=(1,1)
435 for (auto i = 1; i < jacobian.extent(0); i++)
436 jacobian(i, i) = -acc_alphas[i - 1];
437 // the rest of the buffer is set to zero
438 };
// Spectral shift mu used to precondition the exponential.
439 Tscal mu = 0;
440 for (auto i = 0; i < ndust; i++) {
441 mu += (1
442 + (acc_rho_d_new_patch[id_a * ndust + i]
443 / acc_rho_new_patch[id_a]))
444 * acc_alphas[i];
445 }
446 mu *= (-dt / (ndust + 1));
447
448 // get pointers to this cell's 5 scratch matrices
449 Tscal *ptr_A = exp_scratch_ptr_base + (id_a * 5 * mat_size_squared);
450 Tscal *ptr_B = exp_scratch_ptr_base + (id_a * 5 * mat_size_squared)
451 + mat_size_squared;
452 Tscal *ptr_F = exp_scratch_ptr_base + (id_a * 5 * mat_size_squared)
453 + 2 * mat_size_squared;
454 Tscal *ptr_I = exp_scratch_ptr_base + (id_a * 5 * mat_size_squared)
455 + 3 * mat_size_squared;
456 Tscal *ptr_Id = exp_scratch_ptr_base + (id_a * 5 * mat_size_squared)
457 + 4 * mat_size_squared;
458
459 // create mdspan(s)
460 std::mdspan<
461 Tscal,
462 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
463 mdspan_A(ptr_A, mat_size, mat_size);
464 std::mdspan<
465 Tscal,
466 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
467 mdspan_B(ptr_B, mat_size, mat_size);
468 std::mdspan<
469 Tscal,
470 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
471 mdspan_F(ptr_F, mat_size, mat_size);
472 std::mdspan<
473 Tscal,
474 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
475 mdspan_I(ptr_I, mat_size, mat_size);
476 std::mdspan<
477 Tscal,
478 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
479 mdspan_Id(ptr_Id, mat_size, mat_size);
480
481 get_jacobian(id_a, mdspan_A);
482
483 // pre-processing step: A <- dt*A - mu*Id
484 shammath::mat_set_identity<Tscal>(mdspan_Id);
485 shammath::mat_axpy_beta<Tscal, Tscal>(-mu, mdspan_Id, dt, mdspan_A);
486
487 // compute matrix exponential (K_exp terms of the series)
488 const i32 K_exp = 9;
489 shammath::mat_exp<Tscal, Tscal>(
490 K_exp, mdspan_A, mdspan_F, mdspan_B, mdspan_I, mdspan_Id, ndust + 1);
491
492 // post-processing step: undo the spectral shift
493 shammath::mat_mul_scalar<Tscal>(mdspan_A, sycl::exp(mu));
494
495 // use the matrix exponential to update the momentum
496 Tvec r = {0., 0., 0.}, dd = {0., 0., 0.};
497 r += mdspan_A(0, 0) * acc_rhov_new_patch[id_a];
498
499 for (auto j = 1; j < ndust + 1; j++) {
500 r += mdspan_A(0, j) * acc_rhov_d_new_patch[id_a * ndust + (j - 1)];
501 }
502
503 dd = r - acc_rhov_new_patch[id_a];
504
505 Tscal dissipation = 0, drag_work = 0;
506
507 // compute work of drag terms
508 Tscal inv_rho = 1.0 / (acc_rho_new_patch[id_a]);
509
510 Tvec v_bf = inv_rho * acc_rhov_new_patch[id_a];
511 Tvec v_af = inv_rho * r;
512
513 drag_work = 0.5
514 * (dd[0] * (v_bf[0] + v_af[0]) + dd[1] * (v_bf[1] + v_af[1])
515 + dd[2] * (v_bf[2] + v_af[2]));
516
517 // save gas momentum back
518 acc_rhov_old[id_a] = r;
519 acc_rho_old[id_a] = acc_rho_new_patch[id_a];
520
521 for (auto d_id = 1; d_id <= ndust; d_id++) {
522 r *= 0;
523 r += mdspan_A(d_id, 0) * acc_rhov_new_patch[id_a];
524
525 for (auto j = 1; j <= ndust; j++) {
526
527 r += mdspan_A(d_id, j)
528 * acc_rhov_d_new_patch[id_a * ndust + (j - 1)];
529 }
530
531 dd = r - acc_rhov_d_new_patch[id_a * ndust + (d_id - 1)];
532
533 inv_rho = 1.0 / (acc_rho_d_new_patch[id_a * ndust + (d_id - 1)]);
534
535 v_bf = inv_rho * acc_rhov_d_new_patch[id_a * ndust + (d_id - 1)];
536
537 v_af = inv_rho * r;
538
539 // compute dissipation by the d_id-th dust
540 dissipation
541 += 0.5
542 * (dd[0] * (v_bf[0] + v_af[0]) + dd[1] * (v_bf[1] + v_af[1])
543 + dd[2] * (v_bf[2] + v_af[2]));
544
545 // save dust momentum back
546 acc_rhov_d_old[id_a * ndust + (d_id - 1)] = r;
547 acc_rho_d_old[id_a * ndust + (d_id - 1)]
548 = acc_rho_d_new_patch[id_a * ndust + (d_id - 1)];
549 }
550
551 // update energy (drag work if heating enabled, else remove dissipation)
552 acc_rhoe_old[id_a] = acc_rhoe_new_patch[id_a]
553 + (1 - friction_control) * drag_work
554 - friction_control * dissipation;
555 });
556 });
557
558 rho_new_patch.complete_event_state(e);
559 rhov_new_patch.complete_event_state(e);
560 rhoe_new_patch.complete_event_state(e);
561 rho_d_new_patch.complete_event_state(e);
562 rhov_d_new_patch.complete_event_state(e);
563
564 rho_old.complete_event_state(e);
565 rhov_old.complete_event_state(e);
566 rhoe_old.complete_event_state(e);
567 rho_d_old.complete_event_state(e);
568 rhov_d_old.complete_event_state(e);
569
570 alphas_buf.complete_event_state(e);
571 scratch_expo.complete_event_state(e);
572
573 } else {
574
// Sanity check: the chosen group still must fit in local memory.
575 if (loc_mem_size > q.get_device_prop().local_mem_size) {
577 "not enough local memory for expo drag integrator:\n"
578 "loc_mem_size: {} > max_local_mem: {}\n"
579 "loc_acc_size: {}\n"
580 "group_size: {}\n"
581 "ndust: {}\n",
582 loc_mem_size,
584 loc_acc_size,
585 group_size,
586 ndust));
587 }
588
589 auto e = q.submit(depend_list, [&, dt, ndust, friction_control](sycl::handler &cgh) {
590 // local/shared memory alloc for each work-item
591 sycl::local_accessor<Tscal> local_A(loc_acc_size, cgh);
592 sycl::local_accessor<Tscal> local_B(loc_acc_size, cgh);
593 sycl::local_accessor<Tscal> local_F(loc_acc_size, cgh);
594 sycl::local_accessor<Tscal> local_I(loc_acc_size, cgh);
595 sycl::local_accessor<Tscal> local_Id(loc_acc_size, cgh);
596
597 logger::debug_sycl_ln(
598 "SYCL", shambase::format("parallel_for add_drag [expo-shared-mem]"));
599 cgh.parallel_for(
600 shambase::make_range(cell_count, group_size), [=](sycl::nd_item<1> id) {
601 u32 loc_id = id.get_local_id();
602 u32 id_a = id.get_global_id();
// Range is rounded up to a multiple of group_size; drop padding items.
603 if (id_a >= cell_count)
604 return;
605
606 // sparse jacobian matrix
607 auto get_jacobian =
608 [=](u32 id,
609 std::mdspan<
610 Tscal,
611 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
612 &jacobian) {
613 mat_set_nul<Tscal>(jacobian);
614 // fill first row
615 for (auto j = 1; j < jacobian.extent(1); j++)
616 jacobian(0, j) = acc_alphas[j - 1];
617 // fill first column
618 for (auto i = 1; i < jacobian.extent(0); i++) {
619 jacobian(i, 0) = acc_alphas[i - 1]
620 * (acc_rho_d_new_patch[id * ndust + (i - 1)]
621 / acc_rho_new_patch[id]);
622 jacobian(0, 0) -= jacobian(i, 0);
623 }
624 // fill diagonal from (i,j)=(1,1)
625 for (auto i = 1; i < jacobian.extent(0); i++)
626 jacobian(i, i) = -acc_alphas[i - 1];
627 // the rest of the buffer is set to zero
628 };
629
// Spectral shift mu used to precondition the exponential.
630 Tscal mu = 0;
631 for (auto i = 0; i < ndust; i++) {
632 mu += (1
633 + (acc_rho_d_new_patch[id_a * ndust + i]
634 / acc_rho_new_patch[id_a]))
635 * acc_alphas[i];
636 }
637 mu *= (-dt / (ndust + 1));
638
639 // get pointers to this work-item's slice of local memory
640 Tscal *ptr_loc_A = &(local_A[0]) + mat_size_squared * loc_id;
641 Tscal *ptr_loc_B = &(local_B[0]) + mat_size_squared * loc_id;
642 Tscal *ptr_loc_F = &(local_F[0]) + mat_size_squared * loc_id;
643 Tscal *ptr_loc_I = &(local_I[0]) + mat_size_squared * loc_id;
644 Tscal *ptr_loc_Id = &(local_Id[0]) + mat_size_squared * loc_id;
645
646 // create mdspan(s)
647 std::mdspan<
648 Tscal,
649 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
650 mdspan_A(ptr_loc_A, mat_size, mat_size);
651 std::mdspan<
652 Tscal,
653 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
654 mdspan_B(ptr_loc_B, mat_size, mat_size);
655 std::mdspan<
656 Tscal,
657 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
658 mdspan_F(ptr_loc_F, mat_size, mat_size);
659 std::mdspan<
660 Tscal,
661 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
662 mdspan_I(ptr_loc_I, mat_size, mat_size);
663 std::mdspan<
664 Tscal,
665 std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
666 mdspan_Id(ptr_loc_Id, mat_size, mat_size);
667
668 // get local Jacobian matrix
669
670 get_jacobian(id_a, mdspan_A);
671
672 // pre-processing step: A <- dt*A - mu*Id
673 shammath::mat_set_identity<Tscal>(mdspan_Id);
674 shammath::mat_axpy_beta<Tscal, Tscal>(-mu, mdspan_Id, dt, mdspan_A);
675
676 // compute matrix exponential (K_exp terms of the series)
677 const i32 K_exp = 9;
678 shammath::mat_exp<Tscal, Tscal>(
679 K_exp, mdspan_A, mdspan_F, mdspan_B, mdspan_I, mdspan_Id, ndust + 1);
680
681 // post-processing step: undo the spectral shift
682 shammath::mat_mul_scalar<Tscal>(mdspan_A, sycl::exp(mu));
683
684 // use the matrix exponential to update the momentum
685 Tvec r = {0., 0., 0.}, dd = {0., 0., 0.};
686 r += mdspan_A(0, 0) * acc_rhov_new_patch[id_a];
687
688 for (auto j = 1; j < ndust + 1; j++) {
689 r += mdspan_A(0, j) * acc_rhov_d_new_patch[id_a * ndust + (j - 1)];
690 }
691
692 dd = r - acc_rhov_new_patch[id_a];
693
694 Tscal dissipation = 0, drag_work = 0;
695
696 // compute work of drag terms
697 Tscal inv_rho = 1.0 / (acc_rho_new_patch[id_a]);
698
699 Tvec v_bf = inv_rho * acc_rhov_new_patch[id_a];
700 Tvec v_af = inv_rho * r;
701
702 drag_work = 0.5
703 * (dd[0] * (v_bf[0] + v_af[0]) + dd[1] * (v_bf[1] + v_af[1])
704 + dd[2] * (v_bf[2] + v_af[2]));
705
706 // save gas momentum back
707 acc_rhov_old[id_a] = r;
708 acc_rho_old[id_a] = acc_rho_new_patch[id_a];
709
710 for (auto d_id = 1; d_id <= ndust; d_id++) {
711 r *= 0;
712 r += mdspan_A(d_id, 0) * acc_rhov_new_patch[id_a];
713
714 for (auto j = 1; j <= ndust; j++) {
715
716 r += mdspan_A(d_id, j)
717 * acc_rhov_d_new_patch[id_a * ndust + (j - 1)];
718 }
719
720 dd = r - acc_rhov_d_new_patch[id_a * ndust + (d_id - 1)];
721
722 inv_rho = 1.0 / (acc_rho_d_new_patch[id_a * ndust + (d_id - 1)]);
723
724 v_bf = inv_rho * acc_rhov_d_new_patch[id_a * ndust + (d_id - 1)];
725
726 v_af = inv_rho * r;
727
728 // compute dissipation by the d_id-th dust
729 dissipation
730 += 0.5
731 * (dd[0] * (v_bf[0] + v_af[0]) + dd[1] * (v_bf[1] + v_af[1])
732 + dd[2] * (v_bf[2] + v_af[2]));
733
734 // save dust momentum back
735 acc_rhov_d_old[id_a * ndust + (d_id - 1)] = r;
736 acc_rho_d_old[id_a * ndust + (d_id - 1)]
737 = acc_rho_d_new_patch[id_a * ndust + (d_id - 1)];
738 }
739
740 // update energy (drag work if heating enabled, else remove dissipation)
741 acc_rhoe_old[id_a] = acc_rhoe_new_patch[id_a]
742 + (1 - friction_control) * drag_work
743 - friction_control * dissipation;
744 });
745 });
746
747 rho_new_patch.complete_event_state(e);
748 rhov_new_patch.complete_event_state(e);
749 rhoe_new_patch.complete_event_state(e);
750 rho_d_new_patch.complete_event_state(e);
751 rhov_d_new_patch.complete_event_state(e);
752
753 rho_old.complete_event_state(e);
754 rhov_old.complete_event_state(e);
755 rhoe_old.complete_event_state(e);
756 rho_d_old.complete_event_state(e);
757 rhov_d_old.complete_event_state(e);
758
759 alphas_buf.complete_event_state(e);
760 }
761 });
762}
Header file describing a Node Instance.
std::uint32_t u32
32 bit unsigned integer
std::int32_t i32
32 bit integer
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
DeviceProperties & get_device_prop()
Retrieves the properties of the associated device.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
This header file contains utility functions related to exception handling in the code.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
sycl::nd_range< 1 > make_range(u32 length, const u32 group_size=32)
Generate a sycl nd range out of a group size and length.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the pointer does not hold an object.
Definition memory.hpp:110
namespace for math utility
Definition AABB.hpp:26
namespace for the main framework
Definition __init__.py:1
usize local_mem_size
The amount of shared local memory on the device in bytes.
Definition Device.hpp:110
Patch object that contain generic patch information.
Definition Patch.hpp:33