Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
NeighGraphLinkField.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
25#include "shambackends/sycl.hpp"
28
30
// NeighGraphLinkField<T>: a device buffer storing `nvar` values of type T for each link of a
// NeighGraph. The buffer holds link_count * nvar elements (see resize() and the
// (link_count, nvar) constructor).
// NOTE(review): this listing omits source line 32 (the class declaration itself) and the
// constructor signature lines 51, 55 and 59; only those constructors' member-initializer
// lists are visible below.
31 template<class T>
33 public:
// Device-side storage for the per-link values (sham::DeviceBuffer, allocated in USM).
34 sham::DeviceBuffer<T> link_graph_field;
// Number of links the buffer is currently sized for.
35 u32 link_count;
// Number of values stored per link; never modified by resize().
36 u32 nvar;
37
// Resize the buffer for `graph.link_count` links (link_count * nvar elements).
// Reallocates only when the link count changes. DeviceBuffer::resize keeps existing data by
// default (keep_data = true).
38 void resize(NeighGraph &graph) {
39 if (link_count != graph.link_count) {
40 link_count = graph.link_count;
41 link_graph_field.resize(link_count * nvar);
42 }
43 }
// Same as resize(NeighGraph &), but takes an explicit link count.
44 void resize(u32 count) {
45 if (link_count != count) {
46 link_count = count;
47 link_graph_field.resize(link_count * nvar);
48 }
49 }
50
// Empty field (0 elements, link_count = 0). Presumably the constructor taking only `nvar`;
// its signature (line 51) is not in this listing.
// NOTE(review): the initializer list names nvar before link_count, but link_count is declared
// first. Members are therefore initialized in declaration order, and GCC/Clang emit -Wreorder.
// The initializers are independent of each other, so behaviour is unaffected.
52 : link_graph_field(0, shamsys::instance::get_compute_scheduler_ptr()), nvar(nvar),
53 link_count(0) {}
54
// Sized for graph.link_count links, with nvar fixed to 1 (signature line 55 not shown;
// presumably takes the NeighGraph).
56 : link_graph_field(graph.link_count, shamsys::instance::get_compute_scheduler_ptr()),
57 link_count(graph.link_count), nvar(1) {}
58
// Sized for graph.link_count * nvar elements (signature line 59 not shown; presumably
// takes the NeighGraph and nvar).
60 : link_graph_field(
61 graph.link_count * nvar, shamsys::instance::get_compute_scheduler_ptr()),
62 link_count(graph.link_count), nvar(nvar) {}
63
// Explicit sizing: a buffer of link_count * nvar elements.
64 NeighGraphLinkField(u32 link_count, u32 nvar)
65 : link_graph_field(link_count * nvar, shamsys::instance::get_compute_scheduler_ptr()),
66 link_count(link_count), nvar(nvar) {}
67
// Forwards to DeviceBuffer::get_read_access(deps), which returns a read-only const T *.
68 inline auto get_read_access(sham::EventList &deps) const {
69 return link_graph_field.get_read_access(deps);
70 }
// Forwards to DeviceBuffer::get_write_access(deps), which returns a writable T *.
71 inline auto get_write_access(sham::EventList &deps) {
72 return link_graph_field.get_write_access(deps);
73 }
// Hands the event of the kernel that used the buffer to DeviceBuffer::complete_event_state.
// Returning a void expression from a void function is legal C++.
74 inline void complete_event_state(sycl::event e) const {
75 return link_graph_field.complete_event_state(e);
76 }
77 };
78
// Distributed (per-patch) link-field update. For every patch it launches a kernel over that
// patch graph's obj_cnt objects. For each link (id_a -> id_b) of object id_a, it stores
// compute.get_link_field_val(id_a, id_b) at slot link_id of the output field.
// NOTE(review): source lines 82-84 (the remaining parameters, presumably graph, fcomp and
// neigh_graph_field, which the body uses) and line 93 (the callee, presumably
// distributed_data_kernel_call given this argument list) are missing from this listing.
79 template<class LinkFieldCompute, class T>
80 inline void ddupdate_link_field(
81 sham::DeviceScheduler_ptr dev_sched,
85 StackEntry stack_loc{};
86
// NOTE(review): `result` is never used below.
87 auto &result = neigh_graph_field;
88
// Per-patch thread counts: the number of objects in each patch's graph.
// NOTE(review): `block_count` is unused. DistributedData::map passes (u64 id, T &elem), so
// this lambda only fits std::function<u32(u64, T &)> if the element type converts to u32.
// Confirm this, or take the element as `auto &g` and return g.obj_cnt directly instead of
// graph.get(id).
89 shambase::DistributedData<u32> counts = graph.map<u32>([&](u64 id, u32 block_count) {
90 return graph.get(id).obj_cnt;
91 });
92
// NOTE(review): unlike update_link_field, nothing visible here resizes the output field to
// each graph's link count. Presumably callers size it beforehand; confirm this.
94 dev_sched,
95 sham::DDMultiRef{graph, fcomp},
96 sham::DDMultiRef{neigh_graph_field},
97 counts,
// Presumably receives per-patch views in DDMultiRef order: graph -> link_iter,
// fcomp -> compute, neigh_graph_field -> acc_link_field.
98 [](u32 id_a, auto link_iter, auto compute, auto acc_link_field) {
99 link_iter.for_each_object_link_id(id_a, [&](u32 id_b, u32 link_id) {
100 acc_link_field[link_id] = compute.get_link_field_val(id_a, id_b);
101 });
102 });
103 }
104
// Recompute `neigh_graph_field` in place over the links of `graph`.
// Steps:
// - Resize the field to graph.link_count.
// - Take write access to the field and read access to the graph, both via depends_list.
// - Build LinkFieldCompute on the host from `args`.
// - Submit a parallel_for over graph.obj_cnt indices. Each visited link (id_a -> id_b)
//   stores compute.get_link_field_val(id_a, id_b) at slot link_id.
// The kernel event is appended to result_list and reported back to both the field and the
// graph.
// NOTE(review): source line 107 is missing from this listing; it presumably declares the
// queue `q` used by q.submit below.
105 template<class LinkFieldCompute, class T, class... Args>
106 inline void update_link_field(
108 sham::EventList &depends_list,
109 sham::EventList &result_list,
110 NeighGraphLinkField<T> &neigh_graph_field,
111 NeighGraph &graph,
112 Args &&...args) {
113 StackEntry stack_loc{};
114
115 auto &result = neigh_graph_field;
116
// NOTE(review): the kernel writes only slot link_id (not link_id * nvar + k), so this path
// effectively assumes neigh_graph_field.nvar == 1. Any extra slots are left untouched.
117 result.resize(graph);
118
119 auto acc_link_field = result.link_graph_field.get_write_access(depends_list);
120 auto link_iter = graph.get_read_access(depends_list);
121
// Constructed on the host and copied into the kernel by the [=] capture below.
122 LinkFieldCompute compute(std::forward<Args>(args)...);
123
124 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
125 shambase::parallel_for(cgh, graph.obj_cnt, "compute link field", [=](u32 id_a) {
126 link_iter.for_each_object_link_id(id_a, [&](u32 id_b, u32 link_id) {
127 acc_link_field[link_id] = compute.get_link_field_val(id_a, id_b);
128 });
129 });
130 });
131
132 result_list.add_event(e);
133 result.link_graph_field.complete_event_state(e);
134 graph.complete_event_state(e);
135 }
// Recompute `neigh_graph_field` in place when each object carries `nvar` independent
// variables. The field is stored link-major: slot link_id * nvar + k holds variable k of link
// link_id. The kernel runs over graph.obj_cnt * nvar indices. Each index is split into
// (object, variable) and, for every link of that object, stores
// compute.get_link_field_val(id_cell_a * nvar + k, id_cell_b * nvar + k).
// The kernel event is appended to result_list and reported back to the field and the graph.
// NOTE(review): source line 138 is missing from this listing; it presumably declares the
// queue `q` used by q.submit below.
136 template<class LinkFieldCompute, class T, class... Args>
137 inline void update_link_field_indep_nvar(
139 sham::EventList &depends_list,
140 sham::EventList &result_list,
141 NeighGraphLinkField<T> &neigh_graph_field,
142 NeighGraph &graph,
143 u32 nvar,
144 Args &&...args) {
145 StackEntry stack_loc{};
146
147 auto &result = neigh_graph_field;
148
// NOTE(review): resize() sizes the buffer as link_count * result.nvar (the field's own
// member), but the kernel below indexes with the `nvar` argument.
// - If neigh_graph_field.nvar < nvar, the writes run past the end of the buffer.
// - If neigh_graph_field.nvar > nvar, the trailing slots are left stale.
// Consider asserting neigh_graph_field.nvar == nvar here. compute_link_field_indep_nvar
// avoids the mismatch by constructing its result with {graph, nvar}.
149 result.resize(graph);
150
151 auto acc_link_field = result.link_graph_field.get_write_access(depends_list);
152 auto link_iter = graph.get_read_access(depends_list);
153
// Constructed on the host (receiving nvar first) and copied into the kernel by [=].
154 LinkFieldCompute compute(nvar, std::forward<Args>(args)...);
155
156 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
157 shambase::parallel_for(
158 cgh, graph.obj_cnt * nvar, "compute link field indep nvar", [=](u32 idvar_a) {
// Flattened index -> (object id, variable index).
159 const u32 id_cell_a = idvar_a / nvar;
160 const u32 nvar_loc = idvar_a % nvar;
161
162 link_iter.for_each_object_link_id(id_cell_a, [&](u32 id_cell_b, u32 link_id) {
163 acc_link_field[link_id * nvar + nvar_loc] = compute.get_link_field_val(
164 id_cell_a * nvar + nvar_loc, id_cell_b * nvar + nvar_loc);
165 });
166 });
167 });
168
169 result_list.add_event(e);
170 result.link_graph_field.complete_event_state(e);
171 graph.complete_event_state(e);
172 }
173
// Allocate and fill a new link field over `graph`.
// Steps:
// - Construct the result from the graph. Presumably this is the constructor that sizes it to
//   graph.link_count with nvar = 1.
// - Take write/read access via depends_list.
// - Submit a parallel_for over graph.obj_cnt indices. Each visited link (id_a -> id_b)
//   stores compute.get_link_field_val(id_a, id_b) at slot link_id.
// The kernel event is appended to result_list and reported back to the result and the graph.
// Returns the new field by value.
// NOTE(review): source line 176 is missing from this listing; it presumably declares the
// queue `q` used by q.submit below.
174 template<class LinkFieldCompute, class T, class... Args>
175 NeighGraphLinkField<T> compute_link_field(
177 sham::EventList &depends_list,
178 sham::EventList &result_list,
179 NeighGraph &graph,
180 Args &&...args) {
181 StackEntry stack_loc{};
182
183 NeighGraphLinkField<T> result{graph};
184
185 auto acc_link_field = result.link_graph_field.get_write_access(depends_list);
186 auto link_iter = graph.get_read_access(depends_list);
187
188 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
// Unlike update_link_field, the compute functor is built inside the command group and is
// given the handler.
189 LinkFieldCompute compute(cgh, std::forward<Args>(args)...);
190
191 shambase::parallel_for(cgh, graph.obj_cnt, "compute link field", [=](u32 id_a) {
192 link_iter.for_each_object_link_id(id_a, [&](u32 id_b, u32 link_id) {
193 acc_link_field[link_id] = compute.get_link_field_val(id_a, id_b);
194 });
195 });
196 });
197
198 result_list.add_event(e);
199 result.link_graph_field.complete_event_state(e);
200 graph.complete_event_state(e);
201
202 return result;
203 }
204
// Allocate and fill a new link field with `nvar` independent variables per link, stored
// link-major: slot link_id * nvar + k holds variable k.
// The result is constructed with {graph, nvar}, so its buffer size and the kernel's indexing
// use the same nvar. The kernel runs over graph.obj_cnt * nvar indices. Each index is split
// into (object, variable) and, for every link of that object, stores
// compute.get_link_field_val(id_cell_a * nvar + k, id_cell_b * nvar + k).
// The kernel event is appended to result_list and reported back to the result and the graph.
// Returns the new field by value.
// NOTE(review): source line 207 is missing from this listing; it presumably declares the
// queue `q` used by q.submit below.
205 template<class LinkFieldCompute, class T, class... Args>
206 NeighGraphLinkField<T> compute_link_field_indep_nvar(
208 sham::EventList &depends_list,
209 sham::EventList &result_list,
210 NeighGraph &graph,
211 u32 nvar,
212 Args &&...args) {
213
214 StackEntry stack_loc{};
215
216 NeighGraphLinkField<T> result{graph, nvar};
217
218 auto acc_link_field = result.link_graph_field.get_write_access(depends_list);
219 auto link_iter = graph.get_read_access(depends_list);
220
221 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
// Built inside the command group and given the handler plus nvar.
222 LinkFieldCompute compute(cgh, nvar, std::forward<Args>(args)...);
223
224 shambase::parallel_for(
225 cgh, graph.obj_cnt * nvar, "compute link field indep nvar", [=](u32 idvar_a) {
// Flattened index -> (object id, variable index).
226 const u32 id_cell_a = idvar_a / nvar;
227 const u32 nvar_loc = idvar_a % nvar;
228
229 link_iter.for_each_object_link_id(id_cell_a, [&](u32 id_cell_b, u32 link_id) {
230 acc_link_field[link_id * nvar + nvar_loc] = compute.get_link_field_val(
231 id_cell_a * nvar + nvar_loc, id_cell_b * nvar + nvar_loc);
232 });
233 });
234 });
235
236 result_list.add_event(e);
237 result.link_graph_field.complete_event_state(e);
238 graph.complete_event_state(e);
239
240 return result;
241 }
242
243 /*
244 template<class Tvec>
245 class FaceShiftInfo{
246 sycl::buffer<Tvec> link_shift_a;
247 sycl::buffer<Tvec> link_shift_b;
248 u32 link_count;
249
250 FaceShiftInfo(NeighGraph & graph):
251 link_shift_a(graph.link_count),
252 link_shift_b(graph.link_count),
253 link_count(graph.link_count) {}
254 };
255
256 template<class Tvec, class TgridVec, class AMRBLock>
257 FaceShiftInfo<Tvec> get_face_shift_infos(sycl::queue &q,NeighGraph & graph,
258 sycl::buffer<TgridVec> &buf_block_min,
259 sycl::buffer<TgridVec> &buf_block_max,
260 ){
261
262 FaceShiftInfo<Tvec> shifts (graph);
263
264 q.submit([&](sycl::handler &cgh) {
265 NeighGraphLinkiterator link_iter {graph, cgh};
266 LinkFieldCompute compute (cgh, std::forward<Args>(args)...);
267
268 sycl::accessor acc_link_field {result.template link_graph_field, cgh, sycl::write_only,
269 sycl::no_init};
270
271 shambase::parallel_for(cgh, graph.obj_cnt, "compute link field", [=](u32 id_a) {
272
273 link_iter.for_each_object_link(id_a, [&](u32 id_b, u32 link_id){
274 acc_link_field[link_id] = compute.get_link_field_val(id_a, id_b);
275 });
276
277 });
278 });
279
280 return shifts;
281
282 }
283 */
284
285} // namespace shammodels::basegodunov::modules
Header file describing a Node Instance.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
void add_event(sycl::event e)
Add an event to the list of events.
Definition EventList.hpp:87
Represents a collection of objects distributed across patches identified by a u64 id.
DistributedData< Tmap > map(std::function< Tmap(u64, T &)> map_func)
Apply a function to all objects in the collection and return a new collection containing the results.
T & get(u64 id)
Returns a reference to an object in the collection.
void distributed_data_kernel_call(sham::DeviceScheduler_ptr dev_sched, RefIn in, RefOut in_out, const shambase::DistributedData< index_t > &thread_counts, Functor &&func)
A variant of sham::kernel_call for distributed data.
namespace for the basegodunov model modules
A variant of sham::MultiRef for distributed data.