Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
PatchDataLayer.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
19
21#include "shambase/memory.hpp"
23#include "Patch.hpp"
24#include "PatchDataField.hpp"
28#include <variant>
29#include <vector>
30
31namespace shamrock::patch {
35 class PatchDataLayer {
36
37 void init_fields();
38
39 using var_t = FieldVariant<PatchDataField>;
40
41 std::vector<var_t> fields;
42 std::shared_ptr<PatchDataLayerLayout> pdl_ptr;
43
/// Bounds-checked access to the type-erased field stored at position `idx`.
///
/// @param idx position of the requested field inside `fields`
/// @return mutable reference to the variant stored at that position
///
/// When `idx >= fields.size()` an error message is assembled from the layout
/// description (pdl().get_description_str()) and the offending index.
/// NOTE(review): listing line 46 is absent from this extract, so the statement
/// that consumes the message (presumably a throw) is not visible here; confirm
/// against the original header.
44 inline var_t &get_field_variant(u32 idx) {
45 if (idx >= fields.size()) {
47 "the requested field index is out of bounds\n"
48 " current map is : \n"
49 + pdl().get_description_str()
50 + "\n"
51 " arg : idx = "
52 + std::to_string(idx));
53 }
54 return fields[idx];
55 }
56
/// Read-only counterpart of get_field_variant(u32): bounds-checked access to
/// the type-erased field stored at position `idx`.
///
/// @param idx position of the requested field inside `fields`
/// @return const reference to the variant stored at that position
///
/// When `idx >= fields.size()` an error message is assembled from the layout
/// description and the offending index.
/// NOTE(review): listing line 59 is absent from this extract, so the statement
/// that consumes the message (presumably a throw) is not visible here; confirm
/// against the original header.
57 inline const var_t &get_field_variant(u32 idx) const {
58 if (idx >= fields.size()) {
60 "the requested field index is out of bounds\n"
61 " current map is : \n"
62 + pdl().get_description_str()
63 + "\n"
64 " arg : idx = "
65 + std::to_string(idx));
66 }
67 return fields[idx];
68 }
69
70 public:
71 using field_variant_t = var_t;
72
73 inline PatchDataLayerLayout &pdl() { return shambase::get_check_ref(pdl_ptr); }
74 inline const PatchDataLayerLayout &pdl() const { return shambase::get_check_ref(pdl_ptr); }
75
76 inline std::shared_ptr<PatchDataLayerLayout> get_layout_ptr() const { return pdl_ptr; }
77
78 inline PatchDataLayer(const std::shared_ptr<PatchDataLayerLayout> &pdl) : pdl_ptr(pdl) {
79 init_fields();
80 }
81
82 inline PatchDataLayer(const PatchDataLayer &other) : pdl_ptr(other.get_layout_ptr()) {
83
84 NamedStackEntry stack_loc{"PatchDataLayer::copy_constructor", true};
85
86 for (auto &field_var : other.fields) {
87
88 field_var.visit([&](auto &field) {
89 using base_t =
90 typename std::remove_reference<decltype(field)>::type::Field_type;
91 fields.emplace_back(PatchDataField<base_t>(field));
92 });
93 };
94 }
95
/// PatchDataLayer move constructor.
///
/// Move-constructs both the field vector and the layout pointer from `other`.
///
/// @param other layer whose fields and layout pointer are moved from
101 inline PatchDataLayer(PatchDataLayer &&other) noexcept
102 : fields(std::move(other.fields)), pdl_ptr(std::move(other.pdl_ptr)) {}
103
109 inline PatchDataLayer &operator=(PatchDataLayer &&other) noexcept {
110 fields = std::move(other.fields);
111 pdl_ptr = std::move(other.pdl_ptr);
112 return *this;
113 }
114
115 PatchDataLayer &operator=(const PatchDataLayer &other) = delete;
116
117 static PatchDataLayer mock_patchdata(
118 u64 seed, u32 obj_cnt, const std::shared_ptr<PatchDataLayerLayout> &pdl);
119
120 template<class Functor>
121 inline void for_each_field_any(Functor &&func) {
122 for (auto &f : fields) {
123 f.visit([&](auto &arg) {
124 func(arg);
125 });
126 }
127 }
128
129 template<class Functor>
130 inline void for_each_field_any(Functor &&func) const {
131 for (auto &f : fields) {
132 f.visit([&](const auto &arg) {
133 func(arg);
134 });
135 }
136 }
137
138 template<class Func>
139 inline PatchDataLayer(const std::shared_ptr<PatchDataLayerLayout> &pdl, Func &&fct_init)
140 : pdl_ptr(pdl) {
141
142 u32 cnt = 0;
143
144 fct_init(fields);
145 }
146
147 inline PatchDataLayer duplicate() {
148 const PatchDataLayer &current = *this;
149 return PatchDataLayer(current);
150 }
151
152 inline std::unique_ptr<PatchDataLayer> duplicate_to_ptr() {
153 const PatchDataLayer &current = *this;
154 return std::make_unique<PatchDataLayer>(current);
155 }
156
163 void extract_element(u32 pidx, PatchDataLayer &out_pdat);
164
165 void extract_elements(const sham::DeviceBuffer<u32> &idxs, PatchDataLayer &out_pdat);
166
167 void keep_ids(sycl::buffer<u32> &index_map, u32 len);
168
169 void insert_elements(const PatchDataLayer &pdat);
170
179 template<class T>
180 void insert_elements_in_range(PatchDataLayer &pdat, T bmin, T bmax);
181
182 void resize(u32 new_obj_cnt);
183
184 void reserve(u32 new_obj_cnt);
185
186 void expand(u32 obj_cnt);
187
196 void index_remap(sycl::buffer<u32> &index_map, u32 len);
197
206 void index_remap_resize(sycl::buffer<u32> &index_map, u32 len);
207
209 void index_remap_resize(sham::DeviceBuffer<u32> &index_map, u32 len);
210
212 void keep_ids(sham::DeviceBuffer<u32> &index_map, u32 len);
213
215 void remove_ids(const sham::DeviceBuffer<u32> &indexes, u32 len);
216
217 // template<class Tvecbox>
218 // void split_patchdata(PatchDataLayer & pd0,PatchDataLayer & pd1,PatchDataLayer &
219 // pd2,PatchDataLayer & pd3,PatchDataLayer & pd4,PatchDataLayer & pd5,PatchDataLayer &
220 // pd6,PatchDataLayer & pd7,
221 // Tvecbox bmin_p0,Tvecbox bmin_p1,Tvecbox bmin_p2,Tvecbox bmin_p3,Tvecbox
222 // bmin_p4,Tvecbox bmin_p5,Tvecbox bmin_p6,Tvecbox bmin_p7, Tvecbox bmax_p0,Tvecbox
223 // bmax_p1,Tvecbox bmax_p2,Tvecbox bmax_p3,Tvecbox bmax_p4,Tvecbox bmax_p5,Tvecbox
224 // bmax_p6,Tvecbox bmax_p7);
225
226 template<class Tvecbox>
227 void split_patchdata(
228 std::array<std::reference_wrapper<PatchDataLayer>, 8> pdats,
229 std::array<Tvecbox, 8> min_box,
230 std::array<Tvecbox, 8> max_box);
231
232 void append_subset_to(const std::vector<u32> &idxs, PatchDataLayer &pdat);
233 void append_subset_to(sycl::buffer<u32> &idxs_buf, u32 sz, PatchDataLayer &pdat);
234 void append_subset_to(
235 const sham::DeviceBuffer<u32> &idxs_buf, u32 sz, PatchDataLayer &pdat) const;
236
/// Returns the object count reported by the first stored field.
///
/// Only `fields[0]` is queried; the other fields are not inspected here.
///
/// When the layer holds no field, control reaches the error message below
/// instead of returning a count.
/// NOTE(review): listing line 247 is absent from this extract, so the statement
/// that consumes the message (presumably a throw) is not visible here; confirm
/// against the original header.
237 inline u32 get_obj_cnt() const {
238
239 bool is_empty = fields.empty();
240
241 if (!is_empty) {
242 return fields[0].visit_return([](const auto &field) {
243 return field.get_obj_cnt();
244 });
245 }
246
248 "this PatchDataLayer does not contain any fields");
249 }
250
251 inline u64 memsize() {
252 u64 sum = 0;
253
254 for (auto &field_var : fields) {
255
256 field_var.visit([&](auto &field) {
257 sum += field.memsize();
258 });
259 }
260
261 return sum;
262 }
263
264 inline bool is_empty() { return get_obj_cnt() == 0; }
265
266 void synchronize_buf() {
267 for (auto &field_var : fields) {
268 field_var.visit([&](auto &field) {
269 field.synchronize_buf();
270 });
271 }
272 }
273
274 void overwrite(PatchDataLayer &pdat, u32 obj_cnt);
275
276 template<class T>
277 bool check_field_type(u32 idx) {
278 var_t &tmp = get_field_variant(idx);
279
280 PatchDataField<T> *pval = std::get_if<PatchDataField<T>>(&tmp.value);
281
282 if (pval) {
283 return true;
284 } else {
285 return false;
286 }
287 }
288
/// Typed access to the field stored at position `idx`.
///
/// @tparam T element type the caller expects the field to hold
/// @param idx position of the field; validated by get_field_variant()
/// @return mutable reference to the PatchDataField<T> held at `idx`
///
/// When the variant at `idx` does not hold a PatchDataField<T>, an error message
/// is assembled from the layout description and the requested index.
/// NOTE(review): listing line 300 is absent from this extract, so the statement
/// that consumes the message (presumably a throw) is not visible here; confirm
/// against the original header.
289 template<class T>
290 PatchDataField<T> &get_field(u32 idx) {
291
292 var_t &tmp = get_field_variant(idx);
293
294 PatchDataField<T> *pval = std::get_if<PatchDataField<T>>(&tmp.value);
295
296 if (pval) {
297 return *pval;
298 }
299
301 "the request id is not of correct type\n"
302 " current map is : \n"
303 + pdl().get_description_str()
304 + "\n"
305 " arg : idx = "
306 + std::to_string(idx));
307 }
308
/// Read-only typed access to the field stored at position `idx`.
///
/// @tparam T element type the caller expects the field to hold
/// @param idx position of the field; validated by get_field_variant()
/// @return const reference to the PatchDataField<T> held at `idx`
///
/// When the variant at `idx` does not hold a PatchDataField<T>, an error message
/// is assembled from the layout description and the requested index.
/// NOTE(review): listing line 320 is absent from this extract, so the statement
/// that consumes the message (presumably a throw) is not visible here; confirm
/// against the original header.
309 template<class T>
310 const PatchDataField<T> &get_field(u32 idx) const {
311
312 const var_t &tmp = get_field_variant(idx);
313
314 const PatchDataField<T> *pval = std::get_if<PatchDataField<T>>(&tmp.value);
315
316 if (pval) {
317 return *pval;
318 }
319
321 "the request id is not of correct type\n"
322 " current map is : \n"
323 + pdl().get_description_str()
324 + "\n"
325 " arg : idx = "
326 + std::to_string(idx));
327 }
328
329 template<class T>
330 PatchDataField<T> &get_field(const std::string &field_name) {
331 return get_field<T>(pdl().get_field_idx<T>(field_name));
332 }
333
334 template<class T>
335 const PatchDataField<T> &get_field(const std::string &field_name) const {
336 return get_field<T>(pdl().get_field_idx<T>(field_name));
337 }
338
/// Returns the buffer (`get_buf()`) of the field stored at position `idx`.
///
/// @tparam T element type the caller expects the field to hold
/// @param idx position of the field; validated by get_field_variant()
/// @return mutable reference to the sham::DeviceBuffer<T> of that field
///
/// When the variant at `idx` does not hold a PatchDataField<T>, an error message
/// is assembled from the layout description and the requested index.
/// NOTE(review): listing line 350 is absent from this extract, so the statement
/// that consumes the message (presumably a throw) is not visible here; confirm
/// against the original header.
339 template<class T>
340 sham::DeviceBuffer<T> &get_field_buf_ref(u32 idx) {
341
342 var_t &tmp = get_field_variant(idx);
343
344 PatchDataField<T> *pval = std::get_if<PatchDataField<T>>(&tmp.value);
345
346 if (pval) {
347 return pval->get_buf();
348 }
349
351 "the request id is not of correct type\n"
352 " current map is : \n"
353 + pdl().get_description_str()
354 + "\n"
355 " arg : idx = "
356 + std::to_string(idx));
357 }
358
365 template<class T, u32 nvar>
367 return get_field<T>(idx).template get_span<nvar>();
368 }
369
377 template<class T>
379 return get_field<T>(idx).get_span_nvar_dynamic();
380 }
381
/// Forwards to `get_pointer_span()` of the field at index `idx`.
///
/// @tparam T  element type of the field
/// @param idx position of the field; type-checked by get_field<T>()
///
/// NOTE(review): listing line 383, which carries the return type of this
/// function, is absent from this extract; restore it from the original header.
382 template<class T>
384 get_field_pointer_span(u32 idx) {
385 return get_field<T>(idx).get_pointer_span();
386 }
387
/// Body of the consistency check comparing the object count of every field with
/// get_obj_cnt() (i.e. with the count of the first field).
///
/// NOTE(review): two listing lines are absent from this extract:
///  - line 392, the function signature (the member index of this page lists it
///    as `void check_field_obj_cnt_match()`; confirm);
///  - line 397, the statement that consumes the "mismatch in obj cnt" message
///    (presumably a throw; confirm against the original header).
393 u32 cnt = get_obj_cnt();
394 for (auto &field_var : fields) {
395 field_var.visit([&](auto &field) {
396 if (field.get_obj_cnt() != cnt) {
398 "mismatch in obj cnt");
399 }
400 });
401 }
402 }
403
404 // template<class T> inline std::vector<PatchDataField<T> & > get_field_list(){
405 // std::vector<PatchDataField<T> & > ret;
406 //
407 //
408 //}
409
410 template<class T, class Functor>
411 inline void for_each_field(Functor &&func) {
412 for (auto &f : fields) {
413 PatchDataField<T> *pval = std::get_if<PatchDataField<T>>(&f.value);
414
415 if (pval) {
416 func(*pval);
417 }
418 }
419 }
420
421 inline friend bool operator==(PatchDataLayer &p1, PatchDataLayer &p2) {
422 bool check = true;
423
424 if (p1.fields.size() != p2.fields.size()) {
425 return false;
426 }
427
428 for (u32 idx = 0; idx < p1.fields.size(); idx++) {
429
430 bool ret = std::visit(
431 [&](auto &pf1, auto &pf2) -> bool {
432 using t1 = typename std::remove_reference<decltype(pf1)>::type::Field_type;
433 using t2 = typename std::remove_reference<decltype(pf2)>::type::Field_type;
434
435 if constexpr (std::is_same<t1, t2>::value) {
436 return pf1.check_field_match(pf2);
437 } else {
438 return false;
439 }
440 },
441 p1.fields[idx].value,
442 p2.fields[idx].value);
443
444 check = check && ret;
445 }
446
447 return check;
448 }
449
450 void serialize_buf(shamalgs::SerializeHelper &serializer);
451
452 shamalgs::SerializeSize serialize_buf_byte_size();
453
454 static PatchDataLayer deserialize_buf(
455 shamalgs::SerializeHelper &serializer,
456 const std::shared_ptr<PatchDataLayerLayout> &pdl);
457
458 void fields_raz();
459
460 bool has_nan() {
461 StackEntry stack_loc{};
462
463 bool ret = false;
464
465 for (auto &field_var : fields) {
466 field_var.visit([&](auto &field) {
467 if (field.has_nan()) {
468 ret = true;
469 }
470 });
471 }
472 return ret;
473 }
474 bool has_inf() {
475 StackEntry stack_loc{};
476
477 bool ret = false;
478
479 for (auto &field_var : fields) {
480 field_var.visit([&](auto &field) {
481 if (field.has_inf()) {
482 ret = true;
483 }
484 });
485 }
486 return ret;
487 }
488 bool has_nan_or_inf() {
489 StackEntry stack_loc{};
490
491 bool ret = false;
492
493 for (auto &field_var : fields) {
494 field_var.visit([&](auto &field) {
495 if (field.has_nan_or_inf()) {
496 ret = true;
497 }
498 });
499 }
500 return ret;
501 }
502
511 template<class T>
512 void override_patch_field(std::string field_name, std::vector<T> &vec) {
513 u32 len = vec.size();
514 PatchDataField<T> &f = get_field<T>(pdl().get_field_idx<T>(field_name));
515 sycl::buffer<T> buf(vec.data(), len);
516 f.override(buf, len);
517 }
518
529 template<class T>
530 inline std::vector<T> fetch_data(std::string key) {
531
532 std::vector<T> vec;
533
534 auto appender = [&](auto &field) {
535 if (field.get_name() == key) {
536
537 shamlog_debug_ln("PyShamrockCTX", "appending field", key);
538
539 if (!field.is_empty()) {
540 auto acc = field.get_buf().copy_to_stdvec();
541 u32 len = field.get_val_cnt();
542
543 for (u32 i = 0; i < len; i++) {
544 vec.push_back(acc[i]);
545 }
546 }
547 }
548 };
549
550 for_each_field<T>([&](auto &field) {
551 appender(field);
552 });
553
554 return vec;
555 }
556 };
557
/// Insert elements of `pdat` into this layer only if they are within the range.
///
/// Steps visible in the body:
///  1. the layout is asked whether its main field has type `T`
///     (`check_main_field_type<T>()`);
///  2. field 0 of `pdat` is taken as the main field;
///  3. the indices whose main-field value satisfies
///     `shammath::is_in_half_open(val, bmin, bmax)` are collected;
///  4. that subset of `pdat` is appended to `*this` with `append_subset_to`.
///
/// @tparam T   type of the main field and of the range bounds
/// @param pdat source layer the elements are read from
/// @param bmin lower bound passed to is_in_half_open
/// @param bmax upper bound passed to is_in_half_open
///
/// NOTE(review): listing lines 565 and 580 are absent from this extract, so the
/// statements that consume the two error messages below (presumably throws) are
/// not visible here; confirm against the original header.
558 template<class T>
559 inline void PatchDataLayer::insert_elements_in_range(PatchDataLayer &pdat, T bmin, T bmax) {
560
561 StackEntry stack_loc{};
562
563 if (!pdl().check_main_field_type<T>()) {
564
566 "the chosen type for the main field does not match the required template type");
567 }
568
569 PatchDataField<T> &main_field = pdat.get_field<T>(0);
570
571 // Note that using get_ids_vec_where here is safe since nvar for main_field is equal to 1
572 // hence the Lambda cd_true will be applied to each block on the patch. e.g : i * nvar = i
573 auto get_vec_idx = [&](T vmin, T vmax) -> std::vector<u32> {
574 return main_field.get_ids_vec_where(
575 [&](const auto &acc, u32 idx, T vmin, T vmax) {
// Only 3-dimensional vector types are handled; any other dimension
// falls into the error branch below.
576 if (shambase::VectorProperties<T>::dimension == 3) {
577 T val = acc[idx];
578 return shammath::is_in_half_open(val, vmin, vmax);
579 } else {
581 "dimension != 3 is not handled");
582 }
583 },
584 vmin,
585 vmax);
586 };
587
588 // auto get_vec_idx = [&](T vmin, T vmax) -> std::vector<u32> {
589 // return main_field.get_elements_with_range(
590 // [&](T val, T vmin, T vmax) {
591 // if (shambase::VectorProperties<T>::dimension == 3) {
592 // return shammath::is_in_half_open(val, vmin, vmax);
593 // } else {
594 // throw shambase::make_except_with_loc<std::runtime_error>(
595 // "dimension != 3 is not handled");
596 // }
597 // },
598 // vmin,
599 // vmax);
600 // };
601
602 std::vector<u32> idx_lst = get_vec_idx(bmin, bmax);
603
604 shamlog_debug_sycl_ln("PatchDataLayer", "inserting element cnt =", idx_lst.size());
605
// Copy the selected elements of `pdat` into this layer.
606 pdat.append_subset_to(idx_lst, *this);
607 }
608
609} // namespace shamrock::patch
Header file for the patch struct and related function.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::vector< u32 > get_ids_vec_where(Lambdacd &&cd_true, Args &&...args)
Same function as.
Represents a span of data within a PatchDataField.
PatchDataLayer container class, the layout is described in patchdata_layout.
void override_patch_field(std::string field_name, std::vector< T > &vec)
void index_remap(sycl::buffer< u32 > &index_map, u32 len)
this function remaps the patchdatafield like so val[id] = val[index_map[id]] This function can be use...
PatchDataLayer & operator=(PatchDataLayer &&other) noexcept
PatchDataLayer move assignment.
void check_field_obj_cnt_match()
Check that all contained fields have the same object count.
PatchDataFieldSpan< T, shamrock::dynamic_nvar > get_field_span_nvar_dynamic(u32 idx)
returns a PatchDataFieldSpan of the field at index idx, with a dynamic number of variables
PatchDataLayer(PatchDataLayer &&other) noexcept
PatchDataLayer move constructor.
std::vector< T > fetch_data(std::string key)
Fetch data of a patchdata field into a std::vector.
void index_remap_resize(sycl::buffer< u32 > &index_map, u32 len)
this function remaps the patchdatafield like so val[id] = val[index_map[id]] This function can be use...
void remove_ids(const sham::DeviceBuffer< u32 > &indexes, u32 len)
remove some particles ids
void extract_element(u32 pidx, PatchDataLayer &out_pdat)
extract particle at index pidx and insert it in the provided vectors
PatchDataFieldSpan< T, nvar > get_field_span(u32 idx)
returns a PatchDataFieldSpan of the field at index idx, with the given nvar value
void insert_elements_in_range(PatchDataLayer &pdat, T bmin, T bmax)
insert elements of pdat only if they are within the range
This header file contains utility functions related to exception handling in the code.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
ExcptTypes make_except_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Create an exception with a message and a location.
bool is_in_half_open(T val, T min, T max)
Returns true if val lies in the half-open interval [min, max).
Definition intervals.hpp:36
This file contains the definition for the stacktrace related functionality.
shambase::details::NamedBasicStackEntry NamedStackEntry
Alias for shambase::details::NamedBasicStackEntry.
shambase::details::BasicStackEntry StackEntry
Alias for shambase::details::BasicStackEntry.