Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
AnalysisEnergyPotential.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
19#include "shambase/memory.hpp"
23#include "shambackends/math.hpp"
27#include <utility>
28
30
31 template<class Tvec, template<class> class SPHKernel>
33 public:
34 using Tscal = shambase::VecComponent<Tvec>;
36
38
40 Solver &solver;
41 ShamrockCtx &ctx;
42
44 : model(model), ctx(model.ctx), solver(model.solver) {};
45
46 auto get_potential_energy() -> Tscal {
47 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
48 auto dev_sched_ptr = shamsys::instance::get_compute_scheduler_ptr();
49 sham::DeviceQueue &q = shambase::get_check_ref(dev_sched_ptr).get_queue();
50
51 const u32 ixyz = sched.pdl_old().template get_field_idx<Tvec>("xyz");
52 const Tscal pmass = solver.solver_config.gpart_mass;
53
54 Tscal epot = 0;
55
56 struct GravSource {
57 Tvec pos;
58 Tscal mass;
59 };
60
61 std::vector<GravSource> grav_sources;
62
63 if (!solver.storage.sinks.is_empty()) {
64 for (const auto &sink : solver.storage.sinks.get()) {
65 grav_sources.push_back({sink.pos, sink.mass});
66 }
67 }
68
69 using SolverConfigExtForce = typename Solver::Config::ExtForceConfig;
70 using EF_PointMass = typename SolverConfigExtForce::PointMass;
71 using EF_LenseThirring = typename SolverConfigExtForce::LenseThirring;
72 using EF_ShearingBoxForce = typename SolverConfigExtForce::ShearingBoxForce;
73
74 for (const auto &var_force : solver.solver_config.ext_force_config.ext_forces) {
75 if (const EF_PointMass *ext_force = std::get_if<EF_PointMass>(&var_force.val)) {
76 grav_sources.push_back({Tvec{}, ext_force->central_mass});
77 } else if (
78 const EF_LenseThirring *ext_force
79 = std::get_if<EF_LenseThirring>(&var_force.val)) {
80 grav_sources.push_back({Tvec{}, ext_force->central_mass});
81 }
82 }
83
84 if (!grav_sources.empty()) {
85
86 using Tscal4 = sycl::vec<Tscal, 4>;
87 std::vector<Tscal4> sources{};
88
89 for (const auto &grav_source : grav_sources) {
90 sources.push_back(
91 {grav_source.pos.x(),
92 grav_source.pos.y(),
93 grav_source.pos.z(),
94 grav_source.mass});
95 }
96
97 sham::DeviceBuffer<Tscal4> sources_buf(sources.size(), dev_sched_ptr);
98 sources_buf.copy_from_stdvec(sources);
99
100 sched.for_each_patchdata_nonempty([&](const shamrock::patch::Patch p,
102 u32 len = pdat.get_obj_cnt();
103
104 sham::DeviceBuffer<Tscal> epot_part(len, dev_sched_ptr);
105 sham::DeviceBuffer<Tvec> &xyz_buf = pdat.get_field_buf_ref<Tvec>(ixyz);
106
108 q,
109 sham::MultiRef{xyz_buf, sources_buf},
110 sham::MultiRef{epot_part},
111 len,
112 [pmass,
113 G = solver.solver_config.get_constant_G(),
114 source_count = sources.size()](
115 u32 i,
116 const Tvec *__restrict xyz,
117 const Tscal4 *__restrict sources,
118 Tscal *__restrict epot_part) {
119 Tscal loc_epot = 0;
120
121 Tscal smass;
122 Tvec sink_pos;
123
124 for (u32 j = 0; j < source_count; ++j) {
125 Tscal4 source = sources[j];
126
127 smass = source.w();
128 sink_pos = {source.x(), source.y(), source.z()};
129
130 loc_epot += -pmass * G * smass / sycl::length(xyz[i] - sink_pos);
131 }
132 epot_part[i] = loc_epot;
133 });
134
135 epot += shamalgs::primitives::sum(dev_sched_ptr, epot_part, 0, len);
136 });
137 }
138
139 Tscal tot_epot = shamalgs::collective::allreduce_sum(epot);
140
141 Tscal G = solver.solver_config.get_constant_G();
142
143 for (size_t i = 0; i < grav_sources.size(); ++i) {
144 for (size_t j = i + 1; j < grav_sources.size(); ++j) {
145 const auto &sink1 = grav_sources[i];
146 const auto &sink2 = grav_sources[j];
147
148 Tvec delta = sink1.pos - sink2.pos;
149 Tscal d = sycl::length(delta);
150
151 tot_epot += -G * sink1.mass * sink2.mass * sham::inv_sat_positive(d, 1e-16);
152 }
153 }
154
155 return tot_epot;
156 }
157 };
158
159} // namespace shammodels::sph::modules
MPI scheduler.
std::uint32_t u32
32-bit unsigned integer
The MPI scheduler.
A buffer allocated in USM (Unified Shared Memory)
void copy_from_stdvec(const std::vector< T > &vec)
Copy the content of a std::vector into the buffer.
A SYCL queue associated with a device and a context.
The shamrock SPH model.
Definition Model.hpp:55
The shamrock SPH model.
Definition Solver.hpp:61
PatchDataLayer container class; its layout is described in patchdata_layout.
T inv_sat_positive(T v, T minvsat=T{1e-9}, T satval=T{0.}) noexcept
Saturated inverse (for positive inputs only).
Definition math.hpp:841
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the pointer is null.
Definition memory.hpp:110
namespace for the sph model modules
A class that references multiple buffers or similar objects.
Tscal gpart_mass
The mass of each gas particle.
shammodels::ExtForceConfig< Tvec > ExtForceConfig
External force configuration.
Tscal get_constant_G()
Retrieves the value of the constant G based on the unit system.
ExtForceConfig ext_force_config
External force configuration.
Patch object that contains generic patch information.
Definition Patch.hpp:33