Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
PatchDataField.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
20#include "shambase/memory.hpp"
24#include "shamalgs/memory.hpp"
25#include "shamalgs/numeric.hpp"
34#include <array>
35#include <memory>
36#include <random>
37#include <set>
38#include <string>
39#include <utility>
40#include <vector>
41
42template<class T>
44
46 // constexpr utilities (using & constexpr vals)
48
49 static constexpr bool isprimitive = std::is_same<T, f32>::value || std::is_same<T, f64>::value
50 || std::is_same<T, u32>::value
51 || std::is_same<T, u64>::value;
52
53 // clang-format off
54 static constexpr bool is_in_type_list =
55 #define X(args) std::is_same<T, args>::value ||
56 XMAC_LIST_ENABLED_FIELD false
57 #undef X
58 ;
59
60 static_assert(
61 is_in_type_list,
62 "PatchDataField must be one of those types : "
63
64 #define X(args) #args " "
65 XMAC_LIST_ENABLED_FIELD
66 #undef X
67 );
68 // clang-format on
69
70 template<bool B, class Tb = void>
71 using enable_if_t = typename std::enable_if<B, Tb>;
72
73 using EnableIfPrimitive = enable_if_t<isprimitive>;
74
75 using EnableIfVec = enable_if_t<is_in_type_list && (!isprimitive)>;
76
78 // member fields
80
82
83 std::string field_name;
84
85 u32 nvar; // number of variable per object
86
87 inline void check_nvar() const {
88 if (nvar == 0) {
89 throw shambase::make_except_with_loc<std::runtime_error>("nvar is 0 is not allowed");
90 }
91 }
92
94 // Constructors
96
97 public:
98 inline PatchDataField(PatchDataField &&other) noexcept
99 : buf(std::move(other.buf)), field_name(std::move(other.field_name)),
100 nvar(std::move(other.nvar)) {} // move constructor
101
102 inline PatchDataField &operator=(PatchDataField &&other) noexcept {
103 buf = std::move(other.buf);
104 field_name = std::move(other.field_name);
105 nvar = std::move(other.nvar);
106
107 return *this;
108 } // move assignment
109
110 // TODO find a way to add particles easily cf setup require public vector
111
112 using Field_type = T;
113
114 inline PatchDataField(std::string name, u32 nvar)
115 : field_name(std::move(name)), nvar(nvar),
116 buf(0, shamsys::instance::get_compute_scheduler_ptr()) {
117 check_nvar();
118 };
119
120 inline PatchDataField(std::string name, u32 nvar, u32 obj_cnt)
121 : field_name(std::move(name)), nvar(nvar),
122 buf(obj_cnt * nvar, shamsys::instance::get_compute_scheduler_ptr()) {
123 check_nvar();
124 };
125
126 inline PatchDataField(const PatchDataField &other)
127 : field_name(other.field_name), nvar(other.nvar), buf(other.buf.copy()) {
128 check_nvar();
129 }
130
131 inline PatchDataField(sham::DeviceBuffer<T> &&moved_buf, std::string name, u32 nvar)
132 : field_name(name), nvar(nvar), buf(std::forward<sham::DeviceBuffer<T>>(moved_buf)) {
133 check_nvar();
134 }
135
136 inline PatchDataField(sycl::buffer<T> &&moved_buf, u32 obj_cnt, std::string name, u32 nvar)
137 : field_name(name), nvar(nvar), buf(std::forward<sycl::buffer<T>>(moved_buf),
138 obj_cnt * nvar,
139 shamsys::instance::get_compute_scheduler_ptr()) {
140 check_nvar();
141 }
142
143 PatchDataField &operator=(const PatchDataField &other) = delete;
144
146 // member functions
148
149 inline PatchDataField duplicate() const {
150 const PatchDataField &current = *this;
151 return PatchDataField(current);
152 }
153
154 inline PatchDataField duplicate(std::string new_name) const {
155 const PatchDataField &current = *this;
156 PatchDataField ret = PatchDataField(current);
157 ret.field_name = new_name;
158 return ret;
159 }
160
161 inline std::unique_ptr<PatchDataField> duplicate_to_ptr() const {
162 const PatchDataField &current = *this;
163 return std::make_unique<PatchDataField>(current);
164 }
165
166 inline sham::DeviceBuffer<T> &get_buf() { return buf; }
167 inline const sham::DeviceBuffer<T> &get_buf() const { return buf; }
168
169 [[nodiscard]] inline bool is_empty() const { return get_obj_cnt() == 0; }
170
171 [[nodiscard]] inline u64 memsize() const { return buf.get_mem_usage(); }
172
173 [[nodiscard]] inline const u32 &get_nvar() const { return nvar; }
174
175 [[nodiscard]] inline u32 get_obj_cnt() const {
176 size_t sz = buf.get_size();
177 if (sz % nvar != 0) {
179 "the size of the buffer ({}) is not a multiple of the number of variables ({})",
180 sz,
181 nvar));
182 }
183 return shambase::narrow_or_throw<u32>(sz / nvar);
184 }
185
195 [[nodiscard]] inline u32 get_val_cnt() const {
196 return shambase::narrow_or_throw<u32>(buf.get_size());
197 }
198
199 [[nodiscard]] inline const std::string &get_name() const { return field_name; }
200
201 // TODO add overflow check
202 void resize(u32 new_obj_cnt);
203
204 void reserve(u32 new_obj_cnt);
205
206 void expand(u32 obj_to_add);
207
208 void shrink(u32 obj_to_rem);
209
210 void insert_element(T v);
211
212 void apply_offset(T off);
213
214 void insert(const PatchDataField<T> &f2);
215
216 void overwrite(const PatchDataField<T> &f2, u32 obj_cnt);
217
218 void overwrite(const sham::DeviceBuffer<T> &f2, u32 len);
219
220 void override(sycl::buffer<T> &data, u32 cnt);
221
222 void override(std::vector<T> &data, u32 cnt);
223
224 void override(const T val);
225
226 inline void synchronize_buf() { buf.synchronize(); }
227
228 inline std::vector<T> copy_to_stdvec() {
229 auto tmp = buf.copy_to_stdvec();
230 tmp.resize(get_val_cnt());
231 return tmp;
232 }
233
235 // Span utilities
237
244 template<u32 nvar, bool is_pointer_access = shamrock::access_t_span>
249
257 StackEntry stack_loc{};
258 return get_span<shamrock::dynamic_nvar, shamrock::access_t_span>();
259 }
260
262 get_pointer_span() {
263 StackEntry stack_loc{};
264 return get_span<shamrock::dynamic_nvar, shamrock::access_t_pointer>();
265 }
266
268 // get_subsets utilities
270
280 template<class Lambdacd, class... Args>
281 inline std::set<u32> get_ids_set_where(Lambdacd &&cd_true, Args... args) {
282 StackEntry stack_loc{};
283 std::set<u32> idx_cd{};
284 if (get_obj_cnt() > 0) {
285 auto acc = get_buf().copy_to_stdvec();
286
287 for (u32 i = 0; i < get_obj_cnt(); i++) {
288 if (cd_true(acc, i * nvar, args...)) {
289 idx_cd.insert(i);
290 }
291 }
292 }
293 return idx_cd;
294 }
295
306 template<class Lambdacd, class... Args>
307 inline std::vector<u32> get_ids_vec_where(Lambdacd &&cd_true, Args &&...args) {
308 StackEntry stack_loc{};
309 std::vector<u32> idx_cd{};
310 if (get_obj_cnt() > 0) {
311 auto acc = buf.copy_to_stdvec();
312
313 for (u32 i = 0; i < get_obj_cnt(); i++) {
314 if (std::forward<Lambdacd>(cd_true)(acc, i * nvar, std::forward<Args>(args)...)) {
315 idx_cd.push_back(i);
316 }
317 }
318 }
319 return idx_cd;
320 }
321
332 template<class Lambdacd, class... Args>
333 inline std::tuple<std::optional<sycl::buffer<u32>>, u32> get_ids_buf_where(
334 Lambdacd &&cd_true, Args... args) {
335 StackEntry stack_loc{};
336
337 if (get_obj_cnt() > 0) {
338
339 // buffer of booleans to store result of the condition
340 sycl::buffer<u32> mask(get_obj_cnt());
341
342 sham::EventList depends_list;
343 const T *acc = buf.get_read_access(depends_list);
344
345 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
346
347 auto e = q.submit(depends_list, [&, args...](sycl::handler &cgh) {
348 sycl::accessor acc_mask{mask, cgh, sycl::write_only, sycl::no_init};
349 u32 nvar_field = nvar;
350
351 shambase::parallel_for(
352 cgh, get_obj_cnt(), "PatchdataField::get_ids_buf_where", [=](u32 id) {
353 acc_mask[id] = cd_true(acc, id * nvar_field, args...);
354 });
355 });
356
357 buf.complete_event_state(e);
358
360 shamsys::instance::get_compute_queue(), mask, get_obj_cnt());
361 } else {
362 return {std::nullopt, 0};
363 }
364 }
365
366 template<class Lambdacd, class... Args>
367 inline sham::DeviceBuffer<u32> get_ids_where_recycle_buffer(
368 sham::DeviceBuffer<u32> &mask, Lambdacd &&cd_true, Args... args) const {
369 StackEntry stack_loc{};
370
371 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
372 sham::DeviceQueue &q = shambase::get_check_ref(dev_sched).get_queue();
373
374 auto obj_cnt = get_obj_cnt();
375 if (obj_cnt > 0) {
376 // buffer of booleans to store result of the condition
377 mask.resize(obj_cnt);
378
380 q,
381 sham::MultiRef{buf},
382 sham::MultiRef{mask},
383 obj_cnt,
384 [=, nvar_field = nvar](u32 id, const T *__restrict acc, u32 *__restrict acc_mask) {
385 acc_mask[id] = cd_true(acc, id * nvar_field, args...);
386 });
387
388 return shamalgs::stream_compact(dev_sched, mask, obj_cnt);
389 } else {
390 return sham::DeviceBuffer<u32>(0, dev_sched);
391 }
392 }
393
404 template<class Lambdacd, class... Args>
405 inline sham::DeviceBuffer<u32> get_ids_where(Lambdacd &&cd_true, Args... args) const {
406 StackEntry stack_loc{};
407
408 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
409 sham::DeviceQueue &q = shambase::get_check_ref(dev_sched).get_queue();
410
411 auto obj_cnt = get_obj_cnt();
412 if (obj_cnt > 0) {
413 // buffer of booleans to store result of the condition
414 sham::DeviceBuffer<u32> mask(obj_cnt, dev_sched);
415
417 q,
418 sham::MultiRef{buf},
419 sham::MultiRef{mask},
420 obj_cnt,
421 [=, nvar_field = nvar](u32 id, const T *__restrict acc, u32 *__restrict acc_mask) {
422 acc_mask[id] = cd_true(acc, id * nvar_field, args...);
423 });
424
425 return shamalgs::stream_compact(dev_sched, mask, obj_cnt);
426 } else {
427 return sham::DeviceBuffer<u32>(0, dev_sched);
428 }
429 }
430
431 template<class Lambdacd>
432 [[deprecated("please use one of the PatchDataField::get_ids_..._where functions instead")]]
433 std::vector<u32> get_elements_with_range(Lambdacd &&cd_true, T vmin, T vmax);
434
443 template<class LambdaCd>
444 [[deprecated("please use one of the PatchDataField::get_ids_..._where functions instead")]]
445 std::tuple<std::optional<sycl::buffer<u32>>, u32> get_elements_in_half_open(T vmin, T vmax);
446
447 template<class Lambdacd>
448 [[deprecated("please use one of the PatchDataField::get_ids_..._where functions instead")]]
449 std::unique_ptr<sycl::buffer<u32>> get_elements_with_range_buf(
450 Lambdacd &&cd_true, T vmin, T vmax);
451
453 //
455
456 template<class Lambdacd>
457 void check_err_range(Lambdacd &&cd_true, T vmin, T vmax, std::string add_log = "");
458
459 void extract_element(u32 pidx, PatchDataField<T> &to);
460 void extract_elements(const sham::DeviceBuffer<u32> &idxs, PatchDataField<T> &to);
461
462 bool check_field_match(PatchDataField<T> &f2);
463
464 inline void field_raz() {
465 shamlog_debug_ln("PatchDataField", "raz : ", field_name);
467 }
468
475 void append_subset_to(const std::vector<u32> &idxs, PatchDataField &pfield);
476 void append_subset_to(sycl::buffer<u32> &idxs_buf, u32 sz, PatchDataField &pfield);
477 void append_subset_to(
478 const sham::DeviceBuffer<u32> &idxs_buf, u32 sz, PatchDataField &pfield) const;
479
480 inline PatchDataField make_new_from_subset(sycl::buffer<u32> &idxs_buf, u32 sz) {
481 PatchDataField pfield(field_name, nvar);
482 append_subset_to(idxs_buf, sz, pfield);
483 return pfield;
484 }
485
486 inline PatchDataField make_new_from_subset(sham::DeviceBuffer<u32> &idxs_buf, u32 sz) {
487 PatchDataField pfield(field_name, nvar);
488 append_subset_to(idxs_buf, sz, pfield);
489 return pfield;
490 }
491
492 void gen_mock_data(u32 obj_cnt, std::mt19937 &eng);
493
504 void index_remap(sham::DeviceBuffer<u32> &index_map, u32 len);
505
507 void permut_vars(const std::vector<u32> &permut);
508
520 void index_remap_resize(sham::DeviceBuffer<u32> &index_map, u32 len);
521
528 void remove_ids(const sham::DeviceBuffer<u32> &indexes, u32 len);
529
536 void serialize_buf(shamalgs::SerializeHelper &serializer);
537
546 static PatchDataField deserialize_buf(
547 shamalgs::SerializeHelper &serializer, std::string field_name, u32 nvar);
548
554 shamalgs::SerializeSize serialize_buf_byte_size();
555
561 void serialize_full(shamalgs::SerializeHelper &serializer);
562
569 static PatchDataField deserialize_full(shamalgs::SerializeHelper &serializer);
570
576 shamalgs::SerializeSize serialize_full_byte_size();
577
578 T compute_max() const;
579 T compute_min() const;
580 T compute_sum() const;
581
582 shambase::VecComponent<T> compute_dot_sum();
583
584 bool has_nan();
585 bool has_inf();
586 bool has_nan_or_inf();
587
589 // static member functions
591
592 static PatchDataField<T> mock_field(u64 seed, u32 obj_cnt, std::string name, u32 nvar);
593 static PatchDataField<T> mock_field(
594 u64 seed, u32 obj_cnt, std::string name, u32 nvar, T vmin, T vmax);
595};
596
597// TODO add overflow check
598template<class T>
599inline void PatchDataField<T>::resize(u32 new_obj_cnt) {
600 buf.resize(shambase::narrow_or_throw<u32>(u64(new_obj_cnt) * u64(nvar)));
601}
602
603template<class T>
604inline void PatchDataField<T>::reserve(u32 new_obj_cnt) {
605
606 u32 add_cnt = shambase::narrow_or_throw<u32>(u64(new_obj_cnt) * u64(nvar));
607 buf.reserve(add_cnt);
608}
609
610template<class T>
611inline void PatchDataField<T>::expand(u32 obj_to_add) {
612 resize(get_obj_cnt() + obj_to_add);
613}
614
615template<class T>
616inline void PatchDataField<T>::shrink(u32 obj_to_rem) {
617
618 if (obj_to_rem > get_obj_cnt()) {
619
621 "impossible to remove more object than there is in the patchdata field");
622 }
623
624 resize(get_obj_cnt() - obj_to_rem);
625}
626
627template<class T>
628inline void PatchDataField<T>::overwrite(const PatchDataField<T> &f2, u32 obj_cnt) {
629 StackEntry stack_loc{};
630 buf.copy_from(f2.buf, obj_cnt * f2.nvar);
631}
632
633template<class T>
634inline void PatchDataField<T>::overwrite(const sham::DeviceBuffer<T> &f2, u32 len) {
635 StackEntry stack_loc{};
636 buf.copy_from(f2, len);
637}
638
639template<class T>
640inline void PatchDataField<T>::override(sycl::buffer<T> &data, u32 cnt) {
641 StackEntry stack_loc{};
642 buf.copy_from_sycl_buffer(data, cnt);
643}
644
645template<class T>
646inline void PatchDataField<T>::override(std::vector<T> &data, u32 cnt) {
647 StackEntry stack_loc{};
648 buf.copy_from_stdvec(data, cnt);
649}
650
651template<class T>
652inline void PatchDataField<T>::override(const T val) {
653 StackEntry stack_loc{};
654 buf.fill(val);
655}
656
657template<class T>
658template<class Lambdacd>
659inline std::vector<u32> PatchDataField<T>::get_elements_with_range(
660 Lambdacd &&cd_true, T vmin, T vmax) {
661 StackEntry stack_loc{};
662 std::vector<u32> idxs;
663
664 /* Possible GPU version
665 sycl::buffer<u32> valid {size()};
666
667 shamsys::instance::get_compute_queue().submit([&](sycl::handler & cgh){
668 sycl::accessor acc {shambase::get_check_ref(get_buf()), cgh, sycl::read_only};
669 sycl::accessor bools {valid, cgh,sycl::write_only,sycl::no_init};
670
671 shambase::parallel_for(cgh,size(),"get_element_with_range",[=](u32 i){
672 bools[i] = (cd_true(acc[i], vmin, vmax)) ? 1 : 0;
673 });
674
675 });
676
677 std::tuple<std::optional<sycl::buffer<u32>>, u32> ret =
678 shamalgs::numeric::stream_compact(shamsys::instance::get_compute_queue(), valid, size());
679
680 std::vector<u32> idxs;
681
682 {
683 if(std::get<0>(ret).has_value()){
684 idxs = shamalgs::memory::buf_to_vec(*std::get<0>(ret), std::get<1>(ret));
685 }
686 }
687 */
688
689 if (nvar != 1) {
691 }
692
693 {
694 auto acc = buf.copy_to_stdvec();
695
696 for (u32 i = 0; i < get_val_cnt(); i++) {
697 if (cd_true(acc[i], vmin, vmax)) {
698 idxs.push_back(i);
699 }
700 }
701 }
702
703 return idxs;
704}
705
706template<class T>
707template<class Lambdacd>
708inline std::unique_ptr<sycl::buffer<u32>> PatchDataField<T>::get_elements_with_range_buf(
709 Lambdacd &&cd_true, T vmin, T vmax) {
710 std::vector<u32> idxs = get_elements_with_range(std::forward<Lambdacd>(cd_true), vmin, vmax);
711 if (idxs.empty()) {
712 return {};
713 } else {
714 return std::make_unique<sycl::buffer<u32>>(shamalgs::memory::vec_to_buf(idxs));
715 }
716}
717
/**
 * @brief Exception type carrying a message describing a failed range check.
 *
 * The message is stored in a std::string owned by the exception; what()
 * returns a pointer into that string, valid as long as the exception lives.
 */
class PatchDataRangeCheckError : public std::exception {
  public:
    /**
     * @brief Build the exception from a C string.
     *
     * @param message message to store; a null pointer is stored as an empty
     *                message (constructing a std::string from a null pointer
     *                is undefined behaviour)
     */
    explicit PatchDataRangeCheckError(const char *message)
        : msg_(message != nullptr ? message : "") {}

    /**
     * @brief Build the exception from a std::string.
     *
     * @param message message to store (copied)
     */
    explicit PatchDataRangeCheckError(const std::string &message) : msg_(message) {}

    ~PatchDataRangeCheckError() noexcept override = default;

    /// @brief Stored message as a null-terminated C string.
    [[nodiscard]] const char *what() const noexcept override { return msg_.c_str(); }

  protected:
    /// Message returned by what().
    std::string msg_;
};
731
732template<class T>
733template<class Lambdacd>
735 Lambdacd &&cd_true, T vmin, T vmax, std::string add_log) {
736 StackEntry stack_loc{};
737
738 if (is_empty()) {
739 return;
740 }
741
742 if (nvar != 1) {
744 }
745
746 bool error = false;
747 {
748 auto acc = buf.copy_to_stdvec();
749 u32 err_cnt = 0;
750
751 for (u32 i = 0; i < get_val_cnt(); i++) {
752 if (!cd_true(acc[i], vmin, vmax)) {
753 logger::err_ln(
754 "PatchDataField",
755 "obj =",
756 i,
757 "->",
758 acc[i],
759 "not in range [",
760 vmin,
761 ",",
762 vmax,
763 "]");
764 error = true;
765 err_cnt++;
766 if (err_cnt > 50) {
767 logger::err_ln("PatchDataField", "...");
768 break;
769 }
770 }
771 }
772 }
773
774 if (error) {
775 logger::err_ln("PatchDataField", "additional infos :", add_log);
777 }
778}
Header file describing a Node Instance.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
shamrock::PatchDataFieldSpan< T, shamrock::dynamic_nvar > get_span_nvar_dynamic()
Returns a shamrock::PatchDataFieldSpan pointing to the current PatchDataField.
shamrock::PatchDataFieldSpan< T, nvar, is_pointer_access > get_span()
Returns a shamrock::PatchDataFieldSpan pointing to the current PatchDataField.
u32 get_val_cnt() const
Get the number of values stored in the field.
sham::DeviceBuffer< u32 > get_ids_where(Lambdacd &&cd_true, Args... args) const
Same function as.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > get_ids_buf_where(Lambdacd &&cd_true, Args... args)
Same function as.
std::set< u32 > get_ids_set_where(Lambdacd &&cd_true, Args... args)
Get the set of ids of the objects for which the condition is true.
std::vector< u32 > get_ids_vec_where(Lambdacd &&cd_true, Args &&...args)
Same function as.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > get_elements_in_half_open(T vmin, T vmax)
Get the indices of the elements in half open interval.
A buffer allocated in USM (Unified Shared Memory)
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
Represents a span of data within a PatchDataField.
This header file contains utility functions related to exception handling in the code.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
sycl::buffer< T > index_remap(sycl::queue &q, sycl::buffer< T > &source_buf, sycl::buffer< u32 > &index_map, u32 len)
Remap a buffer according to a given index map: result[i] = source_buf[index_map[i]]
Definition algorithm.cpp:32
sycl::buffer< T > vec_to_buf(const std::vector< T > &buf)
Convert a std::vector to a sycl::buffer
Definition memory.cpp:29
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
Definition numeric.cpp:84
void append_subset_to(const sham::DeviceBuffer< T > &buf, const sham::DeviceBuffer< u32 > &idxs_buf, u32 nvar, sham::DeviceBuffer< T > &buf_other, u32 start_enque)
Appends a subset of elements from one buffer to another.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
std::vector< std::string_view > args
Executable argument list (mapped from argv)
Definition cmdopt.cpp:63
Utilities for safe type narrowing conversions.
main include file for memory algorithms
This file contains the definition for the stacktrace related functionality.
A class that references multiple buffers or similar objects.