Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
ComputeEos.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
25#include "shamphys/eos.hpp"
29
30template<class Tscal>
// Particle mass forwarded to rho_h when recovering the density from h
// (compute_eos passes solver_config.gpart_mass).
33 Tscal pmass;
// h-factor forwarded to rho_h (compute_eos passes Kernel::hfactd).
34 Tscal hfact;
35
36 struct accessed {
37 const Tscal *h;
38 Tscal pmass;
39 Tscal hfact;
40
41 Tscal operator()(u32 i) const {
42 using namespace shamrock::sph;
43 return rho_h(pmass, h[i], hfact);
44 }
45 };
46
47 accessed get_read_access(sham::EventList &depends_list) {
48 auto h = buf_h.get_read_access(depends_list);
49 return accessed{h, pmass, hfact};
50 }
51
52 void complete_event_state(sycl::event e) { buf_h.complete_event_state(e); }
53};
54
55template<class Tscal>
// Per-particle dust fractions, stored as nvar_dust consecutive values per particle
// (element j of particle i lives at index i * nvar_dust + j).
58 sham::DeviceBuffer<Tscal> &buf_epsilon;
// Number of dust fraction values stored per particle.
59 u32 nvar_dust;
// Particle mass forwarded to rho_h.
60 Tscal pmass;
// h-factor forwarded to rho_h.
61 Tscal hfact;
62
63 struct accessed {
64 const Tscal *h;
65 const Tscal *buf_epsilon;
66 u32 nvar_dust;
67 Tscal pmass;
68 Tscal hfact;
69
70 Tscal operator()(u32 i) const {
71
72 Tscal epsilon_sum = 0;
73 for (u32 j = 0; j < nvar_dust; j++) {
74 epsilon_sum += buf_epsilon[i * nvar_dust + j];
75 }
76
77 using namespace shamrock::sph;
78 return (1 - epsilon_sum) * rho_h(pmass, h[i], hfact);
79 }
80 };
81
82 accessed get_read_access(sham::EventList &depends_list) {
83 auto h = buf_h.get_read_access(depends_list);
84 auto epsilon = buf_epsilon.get_read_access(depends_list);
85
86 return accessed{h, epsilon, nvar_dust, pmass, hfact};
87 }
88
89 void complete_event_state(sycl::event e) { buf_h.complete_event_state(e); }
90};
91
// compute_eos_internal: for every merged patch (including ghosts), writes
// storage.pressure and storage.soundspeed using the EOS variant currently held in
// solver_config.eos_config.config.
// rho_getter_gen is invoked once per patch with its PatchDataLayer. It must return
// an object exposing get_read_access(EventList&) (yielding a per-index density
// accessor rho(i)) and complete_event_state(event).
// NOTE(review): the signature line and the ghost_layout declaration are missing
// from this listing.
92template<class Tvec, template<class> class SPHKernel>
93template<class RhoGetGen>
95 RhoGetGen &&rho_getter_gen) {
96
98 = shambase::get_check_ref(storage.ghost_layout.get());
// Index of the "uint" field in the ghost layout (read by some branches below).
99 u32 iuint_interf = ghost_layout.get_field_idx<Tscal>("uint");
100
101 using namespace shamrock;
102 using namespace shamrock::patch;
103
// Local aliases for each alternative of the EOS configuration variant.
104 using SolverConfigEOS = typename Config::EOSConfig;
105 using SolverEOS_Isothermal = typename SolverConfigEOS::Isothermal;
106 using SolverEOS_Adiabatic = typename SolverConfigEOS::Adiabatic;
107 using SolverEOS_Polytropic = typename SolverConfigEOS::Polytropic;
108 using SolverEOS_LocallyIsothermal = typename SolverConfigEOS::LocallyIsothermal;
109 using SolverEOS_LocallyIsothermalLP07 = typename SolverConfigEOS::LocallyIsothermalLP07;
110 using SolverEOS_LocallyIsothermalFA2014 = typename SolverConfigEOS::LocallyIsothermalFA2014;
111 using SolverEOS_LocallyIsothermalFA2014Extended =
112 typename SolverConfigEOS::LocallyIsothermalFA2014Extended;
113 using SolverEOS_Fermi = typename SolverConfigEOS::Fermi;
114
// Compute queue on which every kernel of this function is submitted.
115 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
116
// Isothermal EOS: a single sound speed (eos_config->cs) for every particle;
// pressure is EOS::pressure(cs, rho).
117 if (SolverEOS_Isothermal *eos_config
118 = std::get_if<SolverEOS_Isothermal>(&solver_config.eos_config.config)) {
119
121
122 storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) {
124 = shambase::get_check_ref(storage.pressure).get_field(id).get_buf();
126 = shambase::get_check_ref(storage.soundspeed).get_field(id).get_buf();
127 auto rho_getter = rho_getter_gen(mpdat);
128
// Number of particles (real + ghost) of this patch.
129 u32 total_elements
130 = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
131
133 q,
134 sham::MultiRef{rho_getter},
135 sham::MultiRef{buf_P, buf_cs},
136 total_elements,
137 [cs_cfg
138 = eos_config->cs](u32 i, auto rho, Tscal *__restrict P, Tscal *__restrict cs) {
139 using namespace shamrock::sph;
140 Tscal rho_a = rho(i);
141 Tscal P_a = EOS::pressure(cs_cfg, rho_a);
142 P[i] = P_a;
143 cs[i] = cs_cfg;
144 });
145 });
146 } else if (
147 SolverEOS_Adiabatic *eos_config
148 = std::get_if<SolverEOS_Adiabatic>(&solver_config.eos_config.config)) {
// Adiabatic EOS: P = EOS::pressure(gamma, rho, u) from the per-particle "uint"
// field, then cs = EOS::cs_from_p(gamma, rho, P).
149
151
152 storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) {
154 = shambase::get_check_ref(storage.pressure).get_field(id).get_buf();
156 = shambase::get_check_ref(storage.soundspeed).get_field(id).get_buf();
157 sham::DeviceBuffer<Tscal> &buf_uint = mpdat.get_field_buf_ref<Tscal>(iuint_interf);
158 auto rho_getter = rho_getter_gen(mpdat);
159
160 u32 total_elements
161 = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
162
164 q,
165 sham::MultiRef{rho_getter, buf_uint},
166 sham::MultiRef{buf_P, buf_cs},
167 total_elements,
168 [gamma = eos_config->gamma](
169 u32 i,
170 auto rho,
171 const Tscal *__restrict U,
172 Tscal *__restrict P,
173 Tscal *__restrict cs) {
174 using namespace shamrock::sph;
175 Tscal rho_a = rho(i);
176 Tscal P_a = EOS::pressure(gamma, rho_a, U[i]);
177 Tscal cs_a = EOS::cs_from_p(gamma, rho_a, P_a);
178 P[i] = P_a;
179 cs[i] = cs_a;
180 });
181 });
182
183 } else if (
184 SolverEOS_Polytropic *eos_config
185 = std::get_if<SolverEOS_Polytropic>(&solver_config.eos_config.config)) {
// Polytropic EOS: P and cs depend only on density, through (gamma, K).
186
188
189 storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) {
191 = shambase::get_check_ref(storage.pressure).get_field(id).get_buf();
193 = shambase::get_check_ref(storage.soundspeed).get_field(id).get_buf();
194 auto rho_getter = rho_getter_gen(mpdat);
195
196 u32 total_elements
197 = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
198
200 q,
201 sham::MultiRef{rho_getter},
202 sham::MultiRef{buf_P, buf_cs},
203 total_elements,
204 [K = eos_config->K, gamma = eos_config->gamma](
205 u32 i, auto rho, Tscal *__restrict P, Tscal *__restrict cs) {
206 using namespace shamrock::sph;
207 Tscal rho_a = rho(i);
208 Tscal P_a = EOS::pressure(gamma, K, rho_a);
209 Tscal cs_a = EOS::soundspeed(gamma, K, rho_a);
210 P[i] = P_a;
211 cs[i] = cs_a;
212 });
213 });
214
215 } else if (
216 SolverEOS_LocallyIsothermal *eos_config
217 = std::get_if<SolverEOS_LocallyIsothermal>(&solver_config.eos_config.config)) {
// Locally isothermal EOS: the sound speed is read per particle from the
// "soundspeed" field of the ghost layout, and P = EOS::pressure_from_cs(cs^2, rho).
218
220
221 u32 isoundspeed_interf = ghost_layout.get_field_idx<Tscal>("soundspeed");
222
223 storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) {
225 = shambase::get_check_ref(storage.pressure).get_field(id).get_buf();
227 = shambase::get_check_ref(storage.soundspeed).get_field(id).get_buf();
228 sham::DeviceBuffer<Tscal> &buf_uint = mpdat.get_field_buf_ref<Tscal>(iuint_interf);
229 auto rho_getter = rho_getter_gen(mpdat);
230 sham::DeviceBuffer<Tscal> &buf_cs0 = mpdat.get_field_buf_ref<Tscal>(isoundspeed_interf);
231
232 u32 total_elements
233 = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
234
// NOTE(review): buf_uint is passed to the kernel but U is never read below.
236 q,
237 sham::MultiRef{rho_getter, buf_uint, buf_cs0},
238 sham::MultiRef{buf_P, buf_cs},
239 total_elements,
240 [](u32 i,
241 auto rho,
242 const Tscal *__restrict U,
243 const Tscal *__restrict cs0,
244 Tscal *__restrict P,
245 Tscal *__restrict cs) {
246 using namespace shamrock::sph;
247
248 Tscal cs_out = cs0[i];
249 Tscal rho_a = rho(i);
250
251 Tscal P_a = EOS::pressure_from_cs(cs_out * cs_out, rho_a);
252
253 P[i] = P_a;
254 cs[i] = cs_out;
255 });
256 });
257
258 } else if (
259 SolverEOS_LocallyIsothermalLP07 *eos_config
260 = std::get_if<SolverEOS_LocallyIsothermalLP07>(&solver_config.eos_config.config)) {
// LP07 locally isothermal EOS: cs^2 = EOS::soundspeed_sq(cs0^2, |x|^2 / r0^2, -q),
// where |x| is the full-vector distance of the particle from the origin.
261
263
264 storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) {
265 auto &mfield = storage.merged_xyzh.get().get(id);
266
// Field 0 of merged_xyzh holds the particle positions.
267 sham::DeviceBuffer<Tvec> &buf_xyz = mfield.template get_field_buf_ref<Tvec>(0);
268
270 = shambase::get_check_ref(storage.pressure).get_field(id).get_buf();
272 = shambase::get_check_ref(storage.soundspeed).get_field(id).get_buf();
273 sham::DeviceBuffer<Tscal> &buf_uint = mpdat.get_field_buf_ref<Tscal>(iuint_interf);
274 auto rho_getter = rho_getter_gen(mpdat);
275
// Kernel constants, hoisted out of the per-particle lambda.
276 Tscal cs0 = eos_config->cs0;
277 Tscal r0sq = eos_config->r0 * eos_config->r0;
278 Tscal mq = -eos_config->q;
279
280 u32 total_elements
281 = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
282
// NOTE(review): buf_uint is passed to the kernel but U is never read below.
284 q,
285 sham::MultiRef{rho_getter, buf_uint, buf_xyz},
286 sham::MultiRef{buf_P, buf_cs},
287 total_elements,
288 [cs0, r0sq, mq](
289 u32 i,
290 auto rho,
291 const Tscal *__restrict U,
292 const Tvec *__restrict xyz,
293 Tscal *__restrict P,
294 Tscal *__restrict cs) {
295 using namespace shamrock::sph;
296
297 Tvec R = xyz[i];
298 Tscal rho_a = rho(i);
299
300 Tscal Rsq = sycl::dot(R, R);
301 Tscal cs_sq = EOS::soundspeed_sq(cs0 * cs0, Rsq / r0sq, mq);
302 Tscal cs_out = sycl::sqrt(cs_sq);
303
304 Tscal P_a = EOS::pressure_from_cs(cs_sq, rho_a);
305
306 P[i] = P_a;
307 cs[i] = cs_out;
308 });
309 });
310
311 } else if (
312 SolverEOS_LocallyIsothermalFA2014 *eos_config
313 = std::get_if<SolverEOS_LocallyIsothermalFA2014>(&solver_config.eos_config.config)) {
// FA2014 locally isothermal EOS: cs = h_over_r * sqrt(sum_i G m_i / |x_i - x|),
// summed over every sink particle.
314
315 Tscal _G = solver_config.get_constant_G();
316
318
// Copy all sink positions and masses to host vectors, then wrap them in SYCL
// buffers so the kernel can read them through accessors.
319 auto &sink_parts = storage.sinks.get();
320 std::vector<Tvec> sink_pos;
321 std::vector<Tscal> sink_mass;
322 u32 sink_cnt = 0;
323
324 for (auto &s : sink_parts) {
325 sink_pos.push_back(s.pos);
326 sink_mass.push_back(s.mass);
327 sink_cnt++;
328 }
329
// NOTE(review): unlike the FA2014Extended branch, there is no check for an empty
// sink list here; with zero sinks the kernel below yields cs = 0 and P = 0.
330 sycl::buffer<Tvec> sink_pos_buf{sink_pos};
331 sycl::buffer<Tscal> sink_mass_buf{sink_mass};
332
333 storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) {
334 auto &mfield = storage.merged_xyzh.get().get(id);
335
336 sham::DeviceBuffer<Tvec> &buf_xyz = mfield.template get_field_buf_ref<Tvec>(0);
337
339 = shambase::get_check_ref(storage.pressure).get_field(id).get_buf();
341 = shambase::get_check_ref(storage.soundspeed).get_field(id).get_buf();
342 sham::DeviceBuffer<Tscal> &buf_uint = mpdat.get_field_buf_ref<Tscal>(iuint_interf);
343 auto rho_getter = rho_getter_gen(mpdat);
344
345 // TODO: Use the complex kernel call when implemented
346
// Manual access/submit/complete sequence: every buffer acquired here gets its
// event state completed after the submission below.
347 sham::EventList depends_list;
348
349 auto P = buf_P.get_write_access(depends_list);
350 auto cs = buf_cs.get_write_access(depends_list);
351 auto rho = rho_getter.get_read_access(depends_list);
// NOTE(review): U is acquired but not read by the kernel.
352 auto U = buf_uint.get_read_access(depends_list);
353 auto xyz = buf_xyz.get_read_access(depends_list);
354
355 u32 total_elements
356 = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
357
358 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
359 sycl::accessor spos{sink_pos_buf, cgh, sycl::read_only};
360 sycl::accessor smass{sink_mass_buf, cgh, sycl::read_only};
361 u32 scount = sink_cnt;
362
363 Tscal h_over_r = eos_config->h_over_r;
364 Tscal G = _G;
365
366 cgh.parallel_for(sycl::range<1>{total_elements}, [=](sycl::item<1> item) {
367 using namespace shamrock::sph;
368
369 Tvec R = xyz[item];
370 Tscal rho_a = rho(item.get_linear_id());
371
// Sum of G m_i / r_i over all sinks (no softening: a particle exactly at a
// sink position would divide by zero).
372 Tscal mpotential = 0;
373 for (u32 i = 0; i < scount; i++) {
374 Tvec s_r = spos[i] - R;
375 Tscal s_m = smass[i];
376 Tscal s_r_abs = sycl::length(s_r);
377 mpotential += G * s_m / s_r_abs;
378 }
379
380 Tscal cs_out = h_over_r * sycl::sqrt(mpotential);
381 Tscal P_a = EOS::pressure_from_cs(cs_out * cs_out, rho_a);
382
383 P[item] = P_a;
384 cs[item] = cs_out;
385 });
386 });
387
388 buf_P.complete_event_state(e);
389 buf_cs.complete_event_state(e);
390 rho_getter.complete_event_state(e);
391 buf_uint.complete_event_state(e);
392 buf_xyz.complete_event_state(e);
393 });
394
395 } else if (
396 SolverEOS_LocallyIsothermalFA2014Extended *eos_config
397 = std::get_if<SolverEOS_LocallyIsothermalFA2014Extended>(
398 &solver_config.eos_config.config)) {
// FA2014 generalised to an arbitrary exponent q: the sound speed is a power law
// in the mass-weighted inverse distance to (at most n_sinks) sinks, normalised
// by cs0 and r0.
399
400 Tscal _cs0 = eos_config->cs0;
401 Tscal _r0 = eos_config->r0;
402 Tscal _q = eos_config->q;
403 u32 n_sinks = eos_config->n_sinks;
404
406
407 auto &sink_parts = storage.sinks.get();
408 std::vector<Tvec> sink_pos;
409 std::vector<Tscal> sink_mass;
410 u32 sink_cnt = 0;
411
// The size check runs after each push_back, so at least one sink is taken even
// when n_sinks == 0. NOTE(review): confirm that n_sinks == 0 is not meant to
// mean "all sinks".
412 for (auto &s : sink_parts) {
413 sink_pos.push_back(s.pos);
414 sink_mass.push_back(s.mass);
415 sink_cnt++;
416 if (sink_pos.size() >= n_sinks) { // We only consider the first n_sinks sinks
417 break;
418 }
419 }
420
// Report an error when no sink exists; the raising call itself is on a line not
// shown in this listing.
421 if (sink_cnt == 0) {
423 "No sinks found for the equation of state");
424 }
425
426 sycl::buffer<Tvec> sink_pos_buf{sink_pos};
427 sycl::buffer<Tscal> sink_mass_buf{sink_mass};
428
429 storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) {
430 auto &mfield = storage.merged_xyzh.get().get(id);
431
432 sham::DeviceBuffer<Tvec> &buf_xyz = mfield.template get_field_buf_ref<Tvec>(0);
433
435 = shambase::get_check_ref(storage.pressure).get_field(id).get_buf();
437 = shambase::get_check_ref(storage.soundspeed).get_field(id).get_buf();
438 sham::DeviceBuffer<Tscal> &buf_uint = mpdat.get_field_buf_ref<Tscal>(iuint_interf);
439 auto rho_getter = rho_getter_gen(mpdat);
440
441 // TODO: Use the complex kernel call when implemented
442
443 sham::EventList depends_list;
444
445 auto P = buf_P.get_write_access(depends_list);
446 auto cs = buf_cs.get_write_access(depends_list);
447 auto rho = rho_getter.get_read_access(depends_list);
// NOTE(review): U is acquired but the per-particle kernel never reads it.
448 auto U = buf_uint.get_read_access(depends_list);
449 auto xyz = buf_xyz.get_read_access(depends_list);
450
451 u32 total_elements
452 = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
453
454 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
455 sycl::accessor spos{sink_pos_buf, cgh, sycl::read_only};
456 sycl::accessor smass{sink_mass_buf, cgh, sycl::read_only};
457 u32 scount = sink_cnt;
458
459 Tscal cs0 = _cs0;
460 Tscal r0 = _r0;
461 Tscal q = _q;
462
463 Tscal inv_r0_q = 1. / sycl::pow(r0, q);
464
465 cgh.parallel_for(sycl::range<1>{total_elements}, [=](sycl::item<1> item) {
466 using namespace shamrock::sph;
467
468 Tvec R = xyz[item];
469 Tscal rho_a = rho(item.get_linear_id());
470
471 Tscal sink_mass_sum = 0;
472 Tscal pot_sum = 0;
473 for (u32 i = 0; i < scount; i++) {
474 Tvec s_r = spos[i] - R;
475 Tscal s_m = smass[i];
476 Tscal s_r_abs = sycl::length(s_r);
477 sink_mass_sum += s_m;
478 pot_sum += s_m / s_r_abs;
479 }
480
481 Tscal cs_out = cs0 * inv_r0_q * sycl::pow(pot_sum / sink_mass_sum, q);
482 Tscal P_a = EOS::pressure_from_cs(cs_out * cs_out, rho_a);
483
484 P[item] = P_a;
485 cs[item] = cs_out;
486 });
487 });
488
// Hand the kernel event back to every buffer acquired before the submission.
489 buf_P.complete_event_state(e);
490 buf_cs.complete_event_state(e);
491 rho_getter.complete_event_state(e);
492 buf_uint.complete_event_state(e);
493 buf_xyz.complete_event_state(e);
494 });
495
496 } else if (
497 SolverEOS_Fermi *eos_config
498 = std::get_if<SolverEOS_Fermi>(&solver_config.eos_config.config)) {
// Fermi gas EOS: density is converted with density_unit before calling
// EOS_Fermi::pressure_and_soundspeed, and the results are converted back by
// dividing by pressure_unit / velocity_unit. The factors are built from the
// unit system's kilogram/metre/second conversions (presumably SI; confirm
// against shamunits).
499
500 using EOS = shamphys::EOS_Fermi<Tscal>;
501
502 storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) {
504 = shambase::get_check_ref(storage.pressure).get_field(id).get_buf();
506 = shambase::get_check_ref(storage.soundspeed).get_field(id).get_buf();
507 auto rho_getter = rho_getter_gen(mpdat);
508
509 u32 total_elements
510 = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id);
511
// NOTE(review): unit_sys is dereferenced without a check and the conversion
// factors are recomputed for every patch although they do not depend on it.
512 using namespace shamunits;
513 auto unit_sys = *solver_config.unit_sys;
514
515 Tscal mass = unit_sys.template to<units::kilogram>();
516 Tscal length = unit_sys.template to<units::metre>();
517 Tscal time = unit_sys.template to<units::second>();
518
519 Tscal pressure_unit = mass / length / (time * time);
520 Tscal density_unit = mass / (length * length * length);
521 Tscal velocity_unit = length / time;
522
524 q,
525 sham::MultiRef{rho_getter},
526 sham::MultiRef{buf_P, buf_cs},
527 total_elements,
528 [mu_e = eos_config->mu_e, density_unit, pressure_unit, velocity_unit](
529 u32 i, auto rho, Tscal *__restrict P, Tscal *__restrict cs) {
530 Tscal rho_a = rho(i);
531 auto const res = EOS::pressure_and_soundspeed(mu_e, rho_a * density_unit);
532 P[i] = res.pressure / pressure_unit;
533 cs[i] = res.soundspeed / velocity_unit;
534 });
535 });
536
// Any other EOS alternative: handled on a line not shown in this listing
// (presumably shambase::throw_unimplemented()).
537 } else {
539 }
540}
541
// compute_eos: computes pressure and sound speed from the equation of state for
// every merged patch (real + ghost particles). Densities are recovered from the
// "hpart" field via rho_h, and scaled by the gas fraction 1 - sum(epsilon) when
// the dust configuration carries an "epsilon" field.
// NOTE(review): the signature line and the ghost_layout declaration are missing
// from this listing.
542template<class Tvec, template<class> class SPHKernel>
544
545 NamedStackEntry stack_loc{"compute eos"};
546
547 Tscal gpart_mass = solver_config.gpart_mass;
548
549 using namespace shamrock;
550 using namespace shamrock::patch;
551
553 = shambase::get_check_ref(storage.ghost_layout.get());
554 u32 ihpart_interf = ghost_layout.get_field_idx<Tscal>("hpart");
555
// NOTE(review): `utility` is not used in the visible part of this function.
556 shamrock::SchedulerUtility utility(scheduler());
557
558 shambase::DistributedData<u32> &counts_with_ghosts
559 = shambase::get_check_ref(storage.part_counts_with_ghost).indexes;
560
// Size the output fields to the per-patch particle counts including ghosts.
561 shambase::get_check_ref(storage.pressure).ensure_sizes(counts_with_ghosts);
562 shambase::get_check_ref(storage.soundspeed).ensure_sizes(counts_with_ghosts);
563
564 if (solver_config.dust_config.has_epsilon_field()) {
565
// Dusty case: the getter also reads nvar_dust epsilon values per particle.
566 u32 iepsilon_interf = ghost_layout.get_field_idx<Tscal>("epsilon");
567 u32 nvar_dust = solver_config.dust_config.get_dust_nvar();
568
569 compute_eos_internal([&](PatchDataLayer &mpdat) {
571 mpdat.get_field_buf_ref<Tscal>(ihpart_interf),
572 mpdat.get_field_buf_ref<Tscal>(iepsilon_interf),
573 nvar_dust,
574 gpart_mass,
575 Kernel::hfactd};
576 });
577 } else {
// Gas-only case: density comes from h alone.
578 compute_eos_internal([&](PatchDataLayer &mpdat) {
580 mpdat.get_field_buf_ref<Tscal>(ihpart_interf), gpart_mass, Kernel::hfactd};
581 });
582 }
583}
584
585using namespace shammath;
589
constexpr const char * xyz
Position field (3D coordinates)
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
Represents a collection of objects distributed across patches identified by a u64 id.
Module for computing equation of state quantities.
void compute_eos()
Computes pressure and sound speed from equation of state.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
This header file contains utility functions related to exception handling in the code.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error (presumably when the pointer holds no object).
Definition memory.hpp:110
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
namespace for math utility
Definition AABB.hpp:26
namespace for the main framework
Definition __init__.py:1
namespace containing the units library
sph kernels
A class that references multiple buffers or similar objects.
Adiabatic equation of state.
Definition eos.hpp:45
Fermi Gas EoS.
Definition eos.hpp:196
Isothermal equation of state.
Definition eos.hpp:32
Locally isothermal equation of state with radial dependence.
Definition eos.hpp:87
Polytropic equation of state.
Definition eos.hpp:66