34 using Tscal = shambase::VecComponent<Tvec>;
44 : model(model), ctx(model.ctx), solver(model.solver) {};
46 auto get_potential_energy() -> Tscal {
48 auto dev_sched_ptr = shamsys::instance::get_compute_scheduler_ptr();
51 const u32 ixyz = sched.pdl_old().template get_field_idx<Tvec>(
"xyz");
52 const Tscal pmass = solver.solver_config.
gpart_mass;
61 std::vector<GravSource> grav_sources;
63 if (!solver.storage.sinks.is_empty()) {
64 for (
const auto &sink : solver.storage.sinks.get()) {
65 grav_sources.push_back({sink.pos, sink.mass});
70 using EF_PointMass =
typename SolverConfigExtForce::PointMass;
71 using EF_LenseThirring =
typename SolverConfigExtForce::LenseThirring;
72 using EF_ShearingBoxForce =
typename SolverConfigExtForce::ShearingBoxForce;
74 for (
const auto &var_force : solver.solver_config.
ext_force_config.ext_forces) {
75 if (
const EF_PointMass *ext_force = std::get_if<EF_PointMass>(&var_force.val)) {
76 grav_sources.push_back({Tvec{}, ext_force->central_mass});
78 const EF_LenseThirring *ext_force
79 = std::get_if<EF_LenseThirring>(&var_force.val)) {
80 grav_sources.push_back({Tvec{}, ext_force->central_mass});
84 if (!grav_sources.empty()) {
86 using Tscal4 = sycl::vec<Tscal, 4>;
87 std::vector<Tscal4> sources{};
89 for (
const auto &grav_source : grav_sources) {
102 u32 len = pdat.get_obj_cnt();
114 source_count = sources.size()](
116 const Tvec *__restrict xyz,
117 const Tscal4 *__restrict sources,
118 Tscal *__restrict epot_part) {
124 for (u32 j = 0; j < source_count; ++j) {
125 Tscal4 source = sources[j];
128 sink_pos = {source.x(), source.y(), source.z()};
130 loc_epot += -pmass * G * smass / sycl::length(xyz[i] - sink_pos);
132 epot_part[i] = loc_epot;
139 Tscal tot_epot = shamalgs::collective::allreduce_sum(epot);
143 for (
size_t i = 0; i < grav_sources.size(); ++i) {
144 for (
size_t j = i + 1; j < grav_sources.size(); ++j) {
145 const auto &sink1 = grav_sources[i];
146 const auto &sink2 = grav_sources[j];
148 Tvec delta = sink1.pos - sink2.pos;
149 Tscal d = sycl::length(delta);
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.