Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
PatchDataLayer.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
21#include "shambase/string.hpp"
27#include <vector>
28
29namespace shamrock::patch {
30
31 PatchDataLayer PatchDataLayer::mock_patchdata(
32 u64 seed, u32 obj_cnt, const std::shared_ptr<PatchDataLayerLayout> &pdl) {
33 PatchDataLayer pdat{pdl};
34
35 pdat.fields.clear();
36
37 pdat.pdl().for_each_field_any([&](auto &field) {
38 using f_t = typename std::remove_reference<decltype(field)>::type;
39 using base_t = typename f_t::field_T;
40
41 pdat.fields.push_back(
42 var_t{PatchDataField<base_t>::mock_field(seed, obj_cnt, field.name, field.nvar)});
43 });
44
45 return pdat;
46 }
47
48 void PatchDataLayer::init_fields() {
49
50 pdl().for_each_field_any([&](auto &field) {
51 using f_t = typename std::remove_reference<decltype(field)>::type;
52 using base_t = typename f_t::field_T;
53
54 fields.push_back(var_t{PatchDataField<base_t>(field.name, field.nvar)});
55 });
56 }
57
59 StackEntry stack_loc{};
60
61 for (u32 idx = 0; idx < fields.size(); idx++) {
62
63 std::visit(
64 [&](auto &field, auto &out_field) {
65 using t1 = typename std::remove_reference<decltype(field)>::type::Field_type;
66 using t2 =
67 typename std::remove_reference<decltype(out_field)>::type::Field_type;
68
69 if constexpr (std::is_same<t1, t2>::value) {
70 field.extract_element(pidx, out_field);
71 } else {
73 }
74 },
75 fields[idx].value,
76 out_pdat.fields[idx].value);
77 }
78 }
79
// For every field of this layer, call field.extract_elements(idxs, out_field), where
// out_field is the field at the same position in `out_pdat`.
// NOTE(review): fields are paired purely by index; the loop bound is fields.size() and
// out_pdat.fields[idx] is not bounds-checked, so out_pdat is assumed to share this layout.
80 void PatchDataLayer::extract_elements(
81 const sham::DeviceBuffer<u32> &idxs, PatchDataLayer &out_pdat) {
82 StackEntry stack_loc{};
83
84 for (u32 idx = 0; idx < fields.size(); idx++) {
85
86 std::visit(
87 [&](auto &field, auto &out_field) {
88 using t1 = typename std::remove_reference<decltype(field)>::type::Field_type;
89 using t2 =
90 typename std::remove_reference<decltype(out_field)>::type::Field_type;
91
// Only pairs whose Field_type matches are extracted; other type pairs take the else branch.
92 if constexpr (std::is_same<t1, t2>::value) {
93 field.extract_elements(idxs, out_field);
94 } else {
// NOTE(review): the statement of this branch (rendered line 95) is missing from this
// view; presumably it reports a layout mismatch (sibling overloads do) — confirm.
96 }
97 },
98 fields[idx].value,
99 out_pdat.fields[idx].value);
100 }
101 }
102
// For every field of this layer, call field.insert(...) with the field at the same
// position in `pdat`, i.e. append the contents of `pdat` field by field.
// NOTE(review): pdat.fields[idx] is not bounds-checked; pdat is assumed to share this layout.
103 void PatchDataLayer::insert_elements(const PatchDataLayer &pdat) {
104
105 StackEntry stack_loc{};
106
107 for (u32 idx = 0; idx < fields.size(); idx++) {
108
109 std::visit(
// Despite its name, `out_field` here is the SOURCE field taken from `pdat`.
110 [&](auto &field, auto &out_field) {
111 using t1 = typename std::remove_reference<decltype(field)>::type::Field_type;
112 using t2 =
113 typename std::remove_reference<decltype(out_field)>::type::Field_type;
114
115 if constexpr (std::is_same<t1, t2>::value) {
116 field.insert(out_field);
117 } else {
// NOTE(review): the statement of this branch (rendered line 118) is missing from this
// view; presumably it reports a field type mismatch — confirm.
119 }
120 },
121 fields[idx].value,
122 pdat.fields[idx].value);
123 }
124 }
125
// For every field of this layer, call field.overwrite(other, obj_cnt), where `other` is
// the field at the same position in `pdat`. The exact overwrite semantics are those of
// PatchDataField::overwrite (not visible here).
// NOTE(review): pdat.fields[idx] is not bounds-checked; pdat is assumed to share this layout.
126 void PatchDataLayer::overwrite(PatchDataLayer &pdat, u32 obj_cnt) {
127 StackEntry stack_loc{};
128
129 for (u32 idx = 0; idx < fields.size(); idx++) {
130
131 std::visit(
132 [&](auto &field, auto &out_field) {
133 using t1 = typename std::remove_reference<decltype(field)>::type::Field_type;
134 using t2 =
135 typename std::remove_reference<decltype(out_field)>::type::Field_type;
136
137 if constexpr (std::is_same<t1, t2>::value) {
138 field.overwrite(out_field, obj_cnt);
139 } else {
// NOTE(review): the statement of this branch (rendered line 140) is missing from this
// view; presumably it reports a field type mismatch — confirm.
141 }
142 },
143 fields[idx].value,
144 pdat.fields[idx].value);
145 }
146 }
147
148 void PatchDataLayer::resize(u32 new_obj_cnt) {
149
150 for (auto &field_var : fields) {
151 field_var.visit([&](auto &field) {
152 field.resize(new_obj_cnt);
153 });
154 }
155 }
156
157 void PatchDataLayer::reserve(u32 new_obj_cnt) {
158
159 for (auto &field_var : fields) {
160 field_var.visit([&](auto &field) {
161 field.reserve(new_obj_cnt);
162 });
163 }
164 }
165
166 void PatchDataLayer::expand(u32 new_obj_cnt) {
167
168 for (auto &field_var : fields) {
169 field_var.visit([&](auto &field) {
170 field.expand(new_obj_cnt);
171 });
172 }
173 }
174
175 void PatchDataLayer::index_remap(sycl::buffer<u32> &index_map, u32 len) {
176
177 sham::DeviceBuffer<u32> dev_index_map(
178 index_map, len, shamsys::instance::get_compute_scheduler_ptr());
179
180 for (auto &field_var : fields) {
181 field_var.visit([&](auto &field) {
182 field.index_remap(dev_index_map, len);
183 });
184 }
185 }
186
187 void PatchDataLayer::index_remap_resize(sycl::buffer<u32> &index_map, u32 len) {
188 sham::DeviceBuffer<u32> dev_index_map(
189 index_map, len, shamsys::instance::get_compute_scheduler_ptr());
190
191 for (auto &field_var : fields) {
192 field_var.visit([&](auto &field) {
193 field.index_remap_resize(dev_index_map, len);
194 });
195 }
196 }
197
199 for (auto &field_var : fields) {
200 field_var.visit([&](auto &field) {
201 field.index_remap_resize(index_map, len);
202 });
203 }
204 }
205
206 void PatchDataLayer::keep_ids(sycl::buffer<u32> &index_map, u32 len) {
207 index_remap_resize(index_map, len);
208 }
209
210 void PatchDataLayer::keep_ids(sham::DeviceBuffer<u32> &index_map, u32 len) {
211 index_remap_resize(index_map, len);
212 }
213
215 for (auto &field_var : fields) {
216 field_var.visit([&](auto &field) {
217 field.remove_ids(indexes, len);
218 });
219 }
220 }
221
// Append, field by field, the first `sz` indices listed in `idxs_buf` from this layer to
// `pdat` by calling field.append_subset_to(idxs_buf, sz, out_field) on positionally
// matched fields.
// NOTE(review): pdat.fields[idx] is not bounds-checked; pdat is assumed to share this layout.
222 void PatchDataLayer::append_subset_to(
223 sycl::buffer<u32> &idxs_buf, u32 sz, PatchDataLayer &pdat) {
224 StackEntry stack_loc{};
225
226 for (u32 idx = 0; idx < fields.size(); idx++) {
227
228 std::visit(
229 [&](auto &field, auto &out_field) {
230 using t1 = typename std::remove_reference<decltype(field)>::type::Field_type;
231 using t2 =
232 typename std::remove_reference<decltype(out_field)>::type::Field_type;
233
234 if constexpr (std::is_same<t1, t2>::value) {
235 field.append_subset_to(idxs_buf, sz, out_field);
236 } else {
// NOTE(review): the statement of this branch (rendered line 237) is missing from this
// view; presumably it reports a layout mismatch — confirm.
238 }
239 },
240 fields[idx].value,
241 pdat.fields[idx].value);
242 }
243 }
244
// Host-vector overload: append the objects whose indices are listed in `idxs` from every
// field of this layer to the positionally matching field of `pdat`.
// NOTE(review): pdat.fields[idx] is not bounds-checked; pdat is assumed to share this layout.
245 void PatchDataLayer::append_subset_to(const std::vector<u32> &idxs, PatchDataLayer &pdat) {
246 StackEntry stack_loc{};
247
248 for (u32 idx = 0; idx < fields.size(); idx++) {
249
250 std::visit(
251 [&](auto &field, auto &out_field) {
252 using t1 = typename std::remove_reference<decltype(field)>::type::Field_type;
253 using t2 =
254 typename std::remove_reference<decltype(out_field)>::type::Field_type;
255
256 if constexpr (std::is_same<t1, t2>::value) {
257 field.append_subset_to(idxs, out_field);
258 } else {
// NOTE(review): the statement of this branch (rendered line 259) is missing from this
// view; presumably it reports a layout mismatch — confirm.
260 }
261 },
262 fields[idx].value,
263 pdat.fields[idx].value);
264 }
265 }
266
// Device-buffer overload (const): append the first `sz` indices listed in `idxs_buf`
// from every field of this layer to the positionally matching field of `pdat`.
// On a Field_type mismatch, the else branch builds a message that embeds both layouts'
// description strings.
// NOTE(review): pdat.fields[idx] is not bounds-checked; pdat is assumed to share this layout.
267 void PatchDataLayer::append_subset_to(
268 const sham::DeviceBuffer<u32> &idxs_buf, u32 sz, PatchDataLayer &pdat) const {
269 StackEntry stack_loc{};
270
271 for (u32 idx = 0; idx < fields.size(); idx++) {
272
273 std::visit(
274 [&](auto &field, auto &out_field) {
275 using t1 = typename std::remove_reference<decltype(field)>::type::Field_type;
276 using t2 =
277 typename std::remove_reference<decltype(out_field)>::type::Field_type;
278
279 if constexpr (std::is_same<t1, t2>::value) {
280 field.append_subset_to(idxs_buf, sz, out_field);
281 } else {
// NOTE(review): the call opening this statement (rendered line 282) is missing from this
// view; presumably it is a shambase::throw_with_loc<...>( call — confirm.
283 shambase::format(
284 "Mismatch in layout\n source layout = {}\n dest layout = {}",
285 pdl().get_description_str(),
286 pdat.pdl().get_description_str()));
287 }
288 },
289 fields[idx].value,
290 pdat.fields[idx].value);
291 }
292 }
293
294 void PatchDataLayer::serialize_buf(shamalgs::SerializeHelper &serializer) {
295 StackEntry stack_loc{};
296 for_each_field_any([&](auto &f) {
297 f.serialize_buf(serializer);
298 });
299 }
300
// Return the serialized byte size of the layer: the sum of serialize_buf_byte_size()
// over every field.
// NOTE(review): the declaration/initialization of `sum` (rendered line 302) is missing
// from this view; presumably a zero-initialized shamalgs::SerializeSize — confirm.
301 shamalgs::SerializeSize PatchDataLayer::serialize_buf_byte_size() {
303 for_each_field_any([&](auto &f) {
304 sum += f.serialize_buf_byte_size();
305 });
306 return sum;
307 }
308
// Rebuild a PatchDataLayer with layout `pdl` from `serializer`: the constructor callback
// receives the field container and pushes one field per layout entry, each read from
// `serializer` using the entry's name and nvar.
309 PatchDataLayer PatchDataLayer::deserialize_buf(
310 shamalgs::SerializeHelper &serializer, const std::shared_ptr<PatchDataLayerLayout> &pdl) {
311 StackEntry stack_loc{};
312
313 return PatchDataLayer{pdl, [&](auto &pdat_fields) {
314 pdl->for_each_field_any([&](auto &field) {
315 using f_t =
316 typename std::remove_reference<decltype(field)>::type;
317 using base_t = typename f_t::field_T;
318
// NOTE(review): the start of the pushed expression (rendered line 320) is missing from
// this view; presumably var_t{PatchDataField<base_t>::deserialize_buf( — confirm.
319 pdat_fields.push_back(
321 serializer, field.name, field.nvar)});
322 });
323 }};
324 }
325
326 void PatchDataLayer::fields_raz() {
327 for_each_field_any([&](auto &f) {
328 f.field_raz();
329 });
330 }
331
332 template<class T>
333 void PatchDataLayer::split_patchdata(
334 std::array<std::reference_wrapper<PatchDataLayer>, 8> pdats,
335 std::array<T, 8> min_box,
336 std::array<T, 8> max_box) {
337
338 StackEntry stack_loc{};
339
340 PatchDataField<T> &main_field = fields[0].get_if_ref_throw<T>();
341
342 // auto get_vec_idx = [&](T vmin, T vmax) -> std::vector<u32> {
343 // return main_field.get_elements_with_range(
344 // [&](T val, T vmin, T vmax) {
345 // return Patch::is_in_patch_converted(val, vmin, vmax);
346 // },
347 // vmin,
348 // vmax);
349 // };
350
351 auto get_vec_idx = [&](T vmin, T vmax) -> std::vector<u32> {
352 return main_field.get_ids_vec_where(
353 [&](const auto &acc, u32 idx, T vmin, T vmax) {
354 T val = acc[idx];
355 return Patch::is_in_patch_converted(val, vmin, vmax);
356 },
357 vmin,
358 vmax);
359 };
360
361 std::vector<u32> idx_p0 = get_vec_idx(min_box[0], max_box[0]);
362 std::vector<u32> idx_p1 = get_vec_idx(min_box[1], max_box[1]);
363 std::vector<u32> idx_p2 = get_vec_idx(min_box[2], max_box[2]);
364 std::vector<u32> idx_p3 = get_vec_idx(min_box[3], max_box[3]);
365 std::vector<u32> idx_p4 = get_vec_idx(min_box[4], max_box[4]);
366 std::vector<u32> idx_p5 = get_vec_idx(min_box[5], max_box[5]);
367 std::vector<u32> idx_p6 = get_vec_idx(min_box[6], max_box[6]);
368 std::vector<u32> idx_p7 = get_vec_idx(min_box[7], max_box[7]);
369
370 u32 el_cnt_new = idx_p0.size() + idx_p1.size() + idx_p2.size() + idx_p3.size()
371 + idx_p4.size() + idx_p5.size() + idx_p6.size() + idx_p7.size();
372
373 if (get_obj_cnt() != el_cnt_new) {
374
375 logger::err_ln(
376 "PatchData",
377 "error in patchdata split, the new element count doesn't match the old one");
378
379 logger::err_ln("PatchData", min_box[0], max_box[0]);
380 logger::err_ln("PatchData", min_box[1], max_box[1]);
381 logger::err_ln("PatchData", min_box[2], max_box[2]);
382 logger::err_ln("PatchData", min_box[3], max_box[3]);
383 logger::err_ln("PatchData", min_box[4], max_box[4]);
384 logger::err_ln("PatchData", min_box[5], max_box[5]);
385 logger::err_ln("PatchData", min_box[6], max_box[6]);
386 logger::err_ln("PatchData", min_box[7], max_box[7]);
387
388 T vmin = sham::min(min_box[0], min_box[1]);
389 vmin = sham::min(vmin, min_box[2]);
390 vmin = sham::min(vmin, min_box[3]);
391 vmin = sham::min(vmin, min_box[4]);
392 vmin = sham::min(vmin, min_box[5]);
393 vmin = sham::min(vmin, min_box[6]);
394 vmin = sham::min(vmin, min_box[7]);
395
396 T vmax = sham::max(max_box[0], max_box[1]);
397 vmax = sham::max(vmax, max_box[2]);
398 vmax = sham::max(vmax, max_box[3]);
399 vmax = sham::max(vmax, max_box[4]);
400 vmax = sham::max(vmax, max_box[5]);
401 vmax = sham::max(vmax, max_box[6]);
402 vmax = sham::max(vmax, max_box[7]);
403
404 main_field.check_err_range(
405 [&](T val, T vmin, T vmax) {
406 return Patch::is_in_patch_converted(val, vmin, vmax);
407 },
408 vmin,
409 vmax);
410 }
411
412 // TODO create a extract subpatch function
413
414 append_subset_to(idx_p0, pdats[0].get());
415 append_subset_to(idx_p1, pdats[1].get());
416 append_subset_to(idx_p2, pdats[2].get());
417 append_subset_to(idx_p3, pdats[3].get());
418 append_subset_to(idx_p4, pdats[4].get());
419 append_subset_to(idx_p5, pdats[5].get());
420 append_subset_to(idx_p6, pdats[6].get());
421 append_subset_to(idx_p7, pdats[7].get());
422 }
423
424#ifndef DOXYGEN
425 template void PatchDataLayer::split_patchdata(
426 std::array<std::reference_wrapper<PatchDataLayer>, 8> pdats,
427 std::array<f32_3, 8> min_box,
428 std::array<f32_3, 8> max_box);
429 template void PatchDataLayer::split_patchdata(
430 std::array<std::reference_wrapper<PatchDataLayer>, 8> pdats,
431 std::array<f64_3, 8> min_box,
432 std::array<f64_3, 8> max_box);
433 template void PatchDataLayer::split_patchdata(
434 std::array<std::reference_wrapper<PatchDataLayer>, 8> pdats,
435 std::array<u32_3, 8> min_box,
436 std::array<u32_3, 8> max_box);
437 template void PatchDataLayer::split_patchdata(
438 std::array<std::reference_wrapper<PatchDataLayer>, 8> pdats,
439 std::array<u64_3, 8> min_box,
440 std::array<u64_3, 8> max_box);
441 template void PatchDataLayer::split_patchdata(
442 std::array<std::reference_wrapper<PatchDataLayer>, 8> pdats,
443 std::array<i64_3, 8> min_box,
444 std::array<i64_3, 8> max_box);
445#endif
446
447} // namespace shamrock::patch
Header file for the patch struct and related function.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
static PatchDataField deserialize_buf(shamalgs::SerializeHelper &serializer, std::string field_name, u32 nvar)
deserialize a field inverse of serialize_buf
std::vector< u32 > get_ids_vec_where(Lambdacd &&cd_true, Args &&...args)
Same function as.
A buffer allocated in USM (Unified Shared Memory)
PatchDataLayer container class, the layout is described in patchdata_layout.
void index_remap(sycl::buffer< u32 > &index_map, u32 len)
this function remaps the patchdatafield like so val[id] = val[index_map[id]] This function can be use...
void index_remap_resize(sycl::buffer< u32 > &index_map, u32 len)
this function remaps the patchdatafield like so val[id] = val[index_map[id]] This function can be use...
void remove_ids(const sham::DeviceBuffer< u32 > &indexes, u32 len)
remove some particles ids
void extract_element(u32 pidx, PatchDataLayer &out_pdat)
extract particle at index pidx and insert it in the provided vectors
This header file contains utility functions related to exception handling in the code.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
This file contains the definition for the stacktrace related functionality.
static bool is_in_patch_converted(sycl::vec< T, 3 > val, sycl::vec< T, 3 > min_val, sycl::vec< T, 3 > max_val)
check if particle is in the asked range, given the output of @convert_coord
Definition Patch.hpp:210
header file to manage sycl