Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
SGDirectPlummer.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
19
21
// Direct-summation gravity evaluation. For every entry of edges.sizes.indexes, adds to
// field_axyz_ext[i] the pairwise sum over j != i of
//     G * gpart_mass * R / (|R|^2 + epsilon^2)^(3/2),   with R = xyz[j] - xyz[i].
// Two code paths compute the same sum: a double loop over std::vector copies
// (reference_mode) and an nd_range kernel that tiles positions through group-local memory.
// NOTE(review): listing lines 23-24 and 29 are absent from this extraction (the numbering
// skips them), so the function signature and the opening of the call that receives the
// error message below are not visible here.
22 template<class Tvec>
25
26 auto edges = get_edges();
27
// Only a single patch is handled: any other patch count takes this error branch.
28 if (edges.sizes.indexes.get_ids().size() != 1) {
30 "Self gravity direct mode only supports one patch so far, current number "
31 "of patches is : "
32 + std::to_string(edges.sizes.indexes.get_ids().size()));
33 }
34
// Scalar constants read from the input edges.
35 const Tscal G = edges.constant_G.data;
36 const Tscal gpart_mass = edges.gpart_mass.data;
37
// Squared softening length; it is added to |R|^2 inside every rsqrt below.
38 const Tscal gravitational_softening = epsilon * epsilon;
39
// Runs once per entry of sizes.indexes; n is used below as the particle count for patch id.
40 edges.sizes.indexes.for_each([&](u64 id, const u64 &n) {
41 PatchDataField &xyz = edges.field_xyz.get_field(id);
42 PatchDataField &axyz_ext = edges.field_axyz_ext.get_field(id);
43
44 if (reference_mode) {
// Reference path: copy both buffers into std::vector and run a plain double loop.
45 std::vector<Tvec> xyz_vec = xyz.get_buf().copy_to_stdvec();
46 std::vector<Tvec> axyz_ext_vec = axyz_ext.get_buf().copy_to_stdvec();
47
48 for (u64 i = 0; i < n; i++) {
49 Tvec force{0.0f};
50 for (u64 j = 0; j < n; j++) {
// Skip the self-interaction.
51 if (i == j) {
52 continue;
53 }
54
55 Tvec R = xyz_vec[j] - xyz_vec[i];
// r_inv = 1 / sqrt(|R|^2 + epsilon^2)
56 const Tscal r_inv = sycl::rsqrt(
57 R.x() * R.x() + R.y() * R.y() + R.z() * R.z()
58 + gravitational_softening);
59 force += gpart_mass * r_inv * r_inv * r_inv * R;
60 }
// Accumulates onto the existing value; G is applied once per particle.
61 axyz_ext_vec[i] += force * G;
62 }
63
// Write the updated accelerations back into the field buffer.
64 axyz_ext.get_buf().copy_from_stdvec(axyz_ext_vec);
65
66 } else {
67
// Kernel path. The launch size (corrected_len) is a multiple of group_size, so it can
// exceed n; work-items with global_id >= Npart still take part in filling the scratch
// tile and in the barriers, but never write a result (see the guards below).
// NOTE(review): n is narrowed from u64 to u32 here and in the kernel_call_hndl argument.
68 const u32 group_size = 32;
69 const u32 group_cnt = shambase::group_count(static_cast<u32>(n), group_size);
70 const u32 corrected_len = group_cnt * group_size;
71
72 sham::kernel_call_hndl(
73 shamsys::instance::get_compute_scheduler_ptr()->get_queue(),
74 sham::MultiRef{xyz.get_buf()},
75 sham::MultiRef{axyz_ext.get_buf()},
76 static_cast<u32>(n),
77 [corrected_len, group_size, G, gpart_mass, gravitational_softening](
78 u32 Npart, const Tvec *__restrict xyz, Tvec *__restrict axyz_ext) {
79 auto range = sycl::nd_range<1>{corrected_len, group_size};
80
81 return [=](sycl::handler &cgh) {
82 using vec4 = sycl::vec<Tscal, 4>;
83
// Group-local tile, one vec4 per work-item: position in components 0-2, mass in w.
84 auto position_scratch
85 = sycl::local_accessor<vec4>{sycl::range<1>{group_size}, cgh};
86
87 cgh.parallel_for(range, [=](sycl::nd_item<1> tid) {
88 const u64 global_id = tid.get_global_linear_id();
89 const u64 local_id = tid.get_local_linear_id();
90
91 Tvec force{0.0f};
92
// Out-of-range work-items get a dummy position; their result is never stored.
93 const Tvec my_particle
94 = (global_id < Npart) ? xyz[global_id] : Tvec{0.0f};
95
// Walk over all particles in tiles of group_size.
96 for (u32 offset = 0; offset < Npart; offset += group_size) {
97
// Each work-item loads one particle of the tile; slots past Npart are zero-filled
// (position 0, mass 0).
98 if (offset + local_id < Npart) {
99 position_scratch[local_id]
100 = vec4{xyz[offset + local_id], gpart_mass};
101 } else {
102 position_scratch[local_id] = vec4{0.0f};
103 }
104
// The tile must be fully written before any work-item reads it.
105 sycl::group_barrier(tid.get_group());
106
107 for (u32 i = 0; i < group_size; ++i) {
108 const Tvec p
109 = position_scratch[i].template swizzle<0, 1, 2>();
110 const Tvec R = p - my_particle;
111
112 const Tscal r_inv = sycl::rsqrt(
113 R.x() * R.x() + R.y() * R.y() + R.z() * R.z()
114 + gravitational_softening);
115
// Skip the self-interaction. Zero-filled padding slots are not skipped here; they
// only contribute nothing because their mass (w) is 0.
// NOTE(review): for a padding slot R = -my_particle. If gravitational_softening is 0
// and my_particle is exactly {0,0,0}, r_inv is rsqrt(0) (presumably inf) and the
// 0 * inf product would be NaN, unlike the reference path. Confirm epsilon > 0 is
// guaranteed, or also require (offset + i < Npart).
116 if (global_id != offset + i) {
117 force += position_scratch[i].w() * r_inv * r_inv * r_inv
118 * R;
119 }
120 }
121
// All reads of the tile must finish before the next iteration overwrites it.
122 sycl::group_barrier(tid.get_group());
123 }
124
// Only in-range work-items accumulate their result into the output.
125 if (global_id < Npart) {
126 axyz_ext[global_id] += force * G;
127 }
128 });
129 };
130 });
131 }
132 });
133 }
134} // namespace shammodels::sph::modules
135
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
void copy_from_stdvec(const std::vector< T > &vec)
Copy the content of a std::vector into the buffer.
void _impl_evaluate_internal() override
Evaluate the node.
constexpr u32 group_count(u32 len, u32 group_size)
Calculates the number of groups based on the length and group size.
Definition integer.hpp:125
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
Namespace for the SPH model modules.
#define __shamrock_stack_entry()
Macro to create a stack entry.
A class that references multiple buffers or similar objects.