Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
ComputeCFL.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
18#include "fmt/core.h"
19#include "shammath/riemann.hpp"
22
23template<class Tvec, class TgridVec>
25
26 StackEntry stack_loc{};
27
28 using namespace shamrock::patch;
29 using namespace shamrock;
30 using namespace shammath;
31
32 SchedulerUtility utility(scheduler());
33 ComputeField<Tscal> cfl_dt = utility.make_compute_field<Tscal>("cfl_dt", AMRBlock::block_size);
34
35 // load layout info
36 PatchDataLayerLayout &pdl = scheduler().pdl_old();
37
38 const u32 icell_min = pdl.get_field_idx<TgridVec>("cell_min");
39 const u32 icell_max = pdl.get_field_idx<TgridVec>("cell_max");
40 const u32 irho = pdl.get_field_idx<Tscal>("rho");
41 const u32 irhoetot = pdl.get_field_idx<Tscal>("rhoetot");
42 const u32 irhovel = pdl.get_field_idx<Tvec>("rhovel");
43
44 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
45 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
46
47 u32 cell_count = pdat.get_obj_cnt() * AMRBlock::block_size;
48
49 sham::DeviceBuffer<TgridVec> &buf_block_min = pdat.get_field_buf_ref<TgridVec>(0);
50 sham::DeviceBuffer<TgridVec> &buf_block_max = pdat.get_field_buf_ref<TgridVec>(1);
51
52 sham::DeviceBuffer<Tscal> &buf_rho = pdat.get_field_buf_ref<Tscal>(irho);
53 sham::DeviceBuffer<Tvec> &buf_rhov = pdat.get_field_buf_ref<Tvec>(irhovel);
54 sham::DeviceBuffer<Tscal> &buf_rhoe = pdat.get_field_buf_ref<Tscal>(irhoetot);
55
56 sham::DeviceBuffer<Tscal> &cfl_dt_buf = cfl_dt.get_buf_check(cur_p.id_patch);
57
58 sham::DeviceBuffer<Tscal> &block_cell_sizes
59 = shambase::get_check_ref(storage.block_cell_sizes)
60 .get_refs()
61 .get(cur_p.id_patch)
62 .get()
63 .get_buf();
64
65 sham::EventList depends_list;
66 auto cfl_dt = cfl_dt_buf.get_write_access(depends_list);
67 auto acc_block_min = buf_block_min.get_read_access(depends_list);
68 auto acc_block_max = buf_block_max.get_read_access(depends_list);
69 auto rho = buf_rho.get_read_access(depends_list);
70 auto rhov = buf_rhov.get_read_access(depends_list);
71 auto rhoe = buf_rhoe.get_read_access(depends_list);
72
73 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
74 Tscal C_safe = solver_config.Csafe;
75 Tscal gamma = solver_config.eos_gamma;
76
77 Tscal one_over_Nside = 1. / AMRBlock::Nside;
78
79 Tscal dxfact = solver_config.grid_coord_to_pos_fact;
80 shambase::parallel_for(cgh, cell_count, "compute_cfl", [=](u64 gid) {
81 const u32 cell_global_id = (u32) gid;
82
83 const u32 block_id = cell_global_id / AMRBlock::block_size;
84 const u32 cell_loc_id = cell_global_id % AMRBlock::block_size;
85
86 TgridVec lower = acc_block_min[block_id];
87 TgridVec upper = acc_block_max[block_id];
88 Tvec lower_flt = lower.template convert<Tscal>() * dxfact;
89 Tvec upper_flt = upper.template convert<Tscal>() * dxfact;
90 Tvec block_cell_size = (upper_flt - lower_flt) * one_over_Nside;
91 Tscal dx = block_cell_size.x();
92
93 auto conststate = shammath::ConsState<Tvec>{rho[gid], rhoe[gid], rhov[gid]};
94
95 auto prim_state = shammath::cons_to_prim(conststate, gamma);
96
97 constexpr Tscal div = 1. / 3.;
98
99 Tscal cs = sound_speed(prim_state, gamma);
100 Tscal vnorm = sycl::length(prim_state.vel);
101 Tscal dt = C_safe * dx * div / (cs + vnorm);
102
103 cfl_dt[gid] = dt;
104 });
105 });
106
107 cfl_dt_buf.complete_event_state(e);
108 buf_block_min.complete_event_state(e);
109 buf_block_max.complete_event_state(e);
110 buf_rho.complete_event_state(e);
111 buf_rhov.complete_event_state(e);
112 buf_rhoe.complete_event_state(e);
113 });
114
115 Tscal rank_dt = cfl_dt.compute_rank_min();
116
117 shamlog_debug_ln("basegodunov", "rank", shamcomm::world_rank(), "found cfl dt =", rank_dt);
118
119 Tscal next_cfl = shamalgs::collective::allreduce_min(rank_dt);
120
121 if (shamcomm::world_rank() == 0) {
122 logger::info_ln("amr::basegodunov", "cfl dt =", next_cfl);
123 }
124
125 return next_cfl;
126}
127
128template<class Tvec, class TgridVec>
130
131 StackEntry stack_loc{};
132
133 using namespace shamrock::patch;
134 using namespace shamrock;
135 using namespace shammath;
136
137 SchedulerUtility utility(scheduler());
138 u32 ndust = solver_config.dust_config.ndust;
139 ComputeField<Tscal> dust_cfl_dt
140 = utility.make_compute_field<Tscal>("dust_cfl_dt", ndust * AMRBlock::block_size);
141
142 // load layout info
143 PatchDataLayerLayout &pdl = scheduler().pdl_old();
144
145 const u32 icell_min = pdl.get_field_idx<TgridVec>("cell_min");
146 const u32 icell_max = pdl.get_field_idx<TgridVec>("cell_max");
147 const u32 irho_dust = pdl.get_field_idx<Tscal>("rho_dust");
148 const u32 irhovel_dust = pdl.get_field_idx<Tvec>("rhovel_dust");
149
150 scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
151 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
152
153 u32 cell_count = pdat.get_obj_cnt() * AMRBlock::block_size;
154
155 sham::DeviceBuffer<Tscal> &buf_rho_dust = pdat.get_field_buf_ref<Tscal>(irho_dust);
156 sham::DeviceBuffer<Tvec> &buf_rhov_dust = pdat.get_field_buf_ref<Tvec>(irhovel_dust);
157
158 sham::DeviceBuffer<Tscal> &dust_cfl_dt_buf = dust_cfl_dt.get_buf_check(cur_p.id_patch);
159
160 sham::DeviceBuffer<Tscal> &block_cell_sizes
161 = shambase::get_check_ref(storage.block_cell_sizes)
162 .get_refs()
163 .get(cur_p.id_patch)
164 .get()
165 .get_buf();
166
167 sham::EventList depends_list;
168 auto dust_cfl_dt = dust_cfl_dt_buf.get_write_access(depends_list);
169 auto rho_dust = buf_rho_dust.get_read_access(depends_list);
170 auto rhov_dust = buf_rhov_dust.get_read_access(depends_list);
171 auto acc_aabb_cell_size = block_cell_sizes.get_read_access(depends_list);
172
173 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
174 Tscal C_safe = solver_config.Csafe;
175
176 shambase::parallel_for(cgh, ndust * cell_count, "compute_dust_cfl", [=](u64 gid) {
177 const u32 tmp_gid = (u32) gid;
178 const u32 cell_global_id = tmp_gid / ndust;
179 const u32 ndust_off_loc = tmp_gid % ndust;
180
181 const u32 block_id = cell_global_id / AMRBlock::block_size;
182 const u32 cell_loc_id = cell_global_id % AMRBlock::block_size;
183
184 auto conststate = shammath::DustConsState<Tvec>{
185 rho_dust[ndust * cell_global_id + ndust_off_loc],
186 rhov_dust[ndust * cell_global_id + ndust_off_loc]};
187 Tscal dx = acc_aabb_cell_size[block_id];
188
189 auto prim_state = shammath::d_cons_to_prim(conststate);
190
191 constexpr Tscal div = 1. / 3.;
192
193 Tscal vnorm = sycl::length(prim_state.vel);
194 Tscal dt = C_safe * dx * div / (vnorm);
195
196 dust_cfl_dt[ndust * cell_global_id + ndust_off_loc] = dt;
197 });
198 });
199
200 dust_cfl_dt_buf.complete_event_state(e);
201 buf_rho_dust.complete_event_state(e);
202 buf_rhov_dust.complete_event_state(e);
203 block_cell_sizes.complete_event_state(e);
204 });
205
206 Tscal rank_dust_dt = dust_cfl_dt.compute_rank_min();
207
208 shamlog_debug_ln(
209 "basegodunov", "rank", shamcomm::world_rank(), "found dust cfl dt =", rank_dust_dt);
210
211 Tscal next_dust_cfl = shamalgs::collective::allreduce_min(rank_dust_dt);
212
213 if (shamcomm::world_rank() == 0) {
214 logger::info_ln("amr::basegodunov", "dust cfl dt =", next_dust_cfl);
215 }
216
217 return next_dust_cfl;
218}
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class; the layout is described in patchdata_layout.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for math utility
Definition AABB.hpp:26
namespace for the main framework
Definition __init__.py:1
From original version by Thomas Guillet (T.A.Guillet@exeter.ac.uk)
This file contains states and Riemann solvers for dust.
Patch object that contains generic patch information.
Definition Patch.hpp:33
u64 id_patch
Unique key that identifies the patch.
Definition Patch.hpp:86