49 static constexpr bool isprimitive = std::is_same<T, f32>::value || std::is_same<T, f64>::value
50 || std::is_same<T, u32>::value
51 || std::is_same<T, u64>::value;
54 static constexpr bool is_in_type_list =
55 #define X(args) std::is_same<T, args>::value ||
56 XMAC_LIST_ENABLED_FIELD
false
62 "PatchDataField must be one of those types : "
64 #define X(args) #args " "
65 XMAC_LIST_ENABLED_FIELD
/// Shorthand for the `std::enable_if<B, Tb>` specialization.
///
/// Note: unlike `std::enable_if_t`, this alias names the `std::enable_if`
/// struct itself, not its nested `::type`.
template<bool B, class Tb = void>
using enable_if_t = std::enable_if<B, Tb>;
73 using EnableIfPrimitive = enable_if_t<isprimitive>;
75 using EnableIfVec = enable_if_t<is_in_type_list && (!isprimitive)>;
83 std::string field_name;
87 inline void check_nvar()
const {
98 inline PatchDataField(PatchDataField &&other) noexcept
99 : buf(std::move(other.buf)), field_name(std::move(other.field_name)),
100 nvar(std::move(other.nvar)) {}
102 inline PatchDataField &operator=(PatchDataField &&other)
noexcept {
103 buf = std::move(other.buf);
104 field_name = std::move(other.field_name);
105 nvar = std::move(other.nvar);
112 using Field_type = T;
114 inline PatchDataField(std::string name,
u32 nvar)
115 : field_name(std::move(name)), nvar(nvar),
116 buf(0, shamsys::instance::get_compute_scheduler_ptr()) {
120 inline PatchDataField(std::string name,
u32 nvar,
u32 obj_cnt)
121 : field_name(std::move(name)), nvar(nvar),
122 buf(obj_cnt * nvar, shamsys::instance::get_compute_scheduler_ptr()) {
126 inline PatchDataField(
const PatchDataField &other)
127 : field_name(other.field_name), nvar(other.nvar), buf(other.buf.copy()) {
136 inline PatchDataField(sycl::buffer<T> &&moved_buf,
u32 obj_cnt, std::string name,
u32 nvar)
137 : field_name(name), nvar(nvar), buf(std::forward<sycl::buffer<T>>(moved_buf),
139 shamsys::instance::get_compute_scheduler_ptr()) {
143 PatchDataField &operator=(
const PatchDataField &other) =
delete;
149 inline PatchDataField duplicate()
const {
150 const PatchDataField &current = *
this;
151 return PatchDataField(current);
154 inline PatchDataField duplicate(std::string new_name)
const {
155 const PatchDataField &current = *
this;
156 PatchDataField ret = PatchDataField(current);
157 ret.field_name = new_name;
161 inline std::unique_ptr<PatchDataField> duplicate_to_ptr()
const {
162 const PatchDataField &current = *
this;
163 return std::make_unique<PatchDataField>(current);
169 [[nodiscard]]
inline bool is_empty()
const {
return get_obj_cnt() == 0; }
171 [[nodiscard]]
inline u64 memsize()
const {
return buf.get_mem_usage(); }
173 [[nodiscard]]
inline const u32 &get_nvar()
const {
return nvar; }
175 [[nodiscard]]
inline u32 get_obj_cnt()
const {
176 size_t sz = buf.get_size();
177 if (sz % nvar != 0) {
179 "the size of the buffer ({}) is not a multiple of the number of variables ({})",
183 return shambase::narrow_or_throw<u32>(sz / nvar);
196 return shambase::narrow_or_throw<u32>(buf.get_size());
199 [[nodiscard]]
inline const std::string &get_name()
const {
return field_name; }
202 void resize(
u32 new_obj_cnt);
204 void reserve(
u32 new_obj_cnt);
206 void expand(
u32 obj_to_add);
208 void shrink(
u32 obj_to_rem);
210 void insert_element(T v);
212 void apply_offset(T off);
214 void insert(
const PatchDataField<T> &f2);
216 void overwrite(
const PatchDataField<T> &f2,
u32 obj_cnt);
218 void overwrite(
const sham::DeviceBuffer<T> &f2,
u32 len);
220 void override(sycl::buffer<T> &data,
u32 cnt);
222 void override(std::vector<T> &data,
u32 cnt);
224 void override(
const T val);
226 inline void synchronize_buf() { buf.synchronize(); }
228 inline std::vector<T> copy_to_stdvec() {
229 auto tmp = buf.copy_to_stdvec();
244 template<u32 nvar,
bool is_pointer_access = shamrock::access_t_span>
280 template<
class Lambdacd,
class... Args>
283 std::set<u32> idx_cd{};
284 if (get_obj_cnt() > 0) {
285 auto acc = get_buf().copy_to_stdvec();
287 for (
u32 i = 0; i < get_obj_cnt(); i++) {
288 if (cd_true(acc, i * nvar, args...)) {
306 template<
class Lambdacd,
class... Args>
309 std::vector<u32> idx_cd{};
310 if (get_obj_cnt() > 0) {
311 auto acc = buf.copy_to_stdvec();
313 for (
u32 i = 0; i < get_obj_cnt(); i++) {
314 if (std::forward<Lambdacd>(cd_true)(acc, i * nvar, std::forward<Args>(args)...)) {
332 template<
class Lambdacd,
class... Args>
334 Lambdacd &&cd_true, Args... args) {
337 if (get_obj_cnt() > 0) {
340 sycl::buffer<u32> mask(get_obj_cnt());
343 const T *acc = buf.get_read_access(depends_list);
347 auto e = q.
submit(depends_list, [&, args...](sycl::handler &cgh) {
348 sycl::accessor acc_mask{mask, cgh, sycl::write_only, sycl::no_init};
349 u32 nvar_field = nvar;
351 shambase::parallel_for(
352 cgh, get_obj_cnt(),
"PatchdataField::get_ids_buf_where", [=](
u32 id) {
353 acc_mask[id] = cd_true(acc,
id * nvar_field, args...);
357 buf.complete_event_state(e);
362 return {std::nullopt, 0};
366 template<
class Lambdacd,
class... Args>
371 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
374 auto obj_cnt = get_obj_cnt();
384 [=, nvar_field = nvar](
u32 id,
const T *__restrict acc,
u32 *__restrict acc_mask) {
385 acc_mask[id] = cd_true(acc,
id * nvar_field,
args...);
404 template<
class Lambdacd,
class... Args>
408 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
411 auto obj_cnt = get_obj_cnt();
421 [=, nvar_field = nvar](
u32 id,
const T *__restrict acc,
u32 *__restrict acc_mask) {
422 acc_mask[id] = cd_true(acc,
id * nvar_field, args...);
431 template<
class Lambdacd>
432 [[deprecated(
"please use one of the PatchDataField::get_ids_..._where functions instead")]]
433 std::vector<u32> get_elements_with_range(Lambdacd &&cd_true, T vmin, T vmax);
443 template<
class LambdaCd>
444 [[deprecated(
"please use one of the PatchDataField::get_ids_..._where functions instead")]]
447 template<
class Lambdacd>
448 [[deprecated(
"please use one of the PatchDataField::get_ids_..._where functions instead")]]
449 std::unique_ptr<sycl::buffer<u32>> get_elements_with_range_buf(
450 Lambdacd &&cd_true, T vmin, T vmax);
456 template<
class Lambdacd>
457 void check_err_range(Lambdacd &&cd_true, T vmin, T vmax, std::string add_log =
"");
459 void extract_element(
u32 pidx, PatchDataField<T> &to);
462 bool check_field_match(PatchDataField<T> &f2);
464 inline void field_raz() {
465 shamlog_debug_ln(
"PatchDataField",
"raz : ", field_name);
466 override(shambase::VectorProperties<T>::get_zero());
480 inline PatchDataField make_new_from_subset(sycl::buffer<u32> &idxs_buf,
u32 sz) {
481 PatchDataField pfield(field_name, nvar);
488 append_subset_to(idxs_buf, sz, pfield);
492 void gen_mock_data(
u32 obj_cnt, std::mt19937 &eng);
578 T compute_max()
const;
579 T compute_min()
const;
580 T compute_sum()
const;
582 shambase::VecComponent<T> compute_dot_sum();
586 bool has_nan_or_inf();
592 static PatchDataField<T> mock_field(
u64 seed,
u32 obj_cnt, std::string name,
u32 nvar);
593 static PatchDataField<T> mock_field(
594 u64 seed,
u32 obj_cnt, std::string name,
u32 nvar, T vmin, T vmax);
599inline void PatchDataField<T>::resize(
u32 new_obj_cnt) {
600 buf.resize(shambase::narrow_or_throw<u32>(
u64(new_obj_cnt) *
u64(nvar)));
604inline void PatchDataField<T>::reserve(
u32 new_obj_cnt) {
606 u32 add_cnt = shambase::narrow_or_throw<u32>(
u64(new_obj_cnt) *
u64(nvar));
607 buf.reserve(add_cnt);
611inline void PatchDataField<T>::expand(
u32 obj_to_add) {
612 resize(get_obj_cnt() + obj_to_add);
616inline void PatchDataField<T>::shrink(
u32 obj_to_rem) {
618 if (obj_to_rem > get_obj_cnt()) {
621 "impossible to remove more object than there is in the patchdata field");
624 resize(get_obj_cnt() - obj_to_rem);
630 buf.copy_from(f2.buf, obj_cnt * f2.nvar);
636 buf.copy_from(f2, len);
640inline void PatchDataField<T>::override(sycl::buffer<T> &data,
u32 cnt) {
642 buf.copy_from_sycl_buffer(data, cnt);
646inline void PatchDataField<T>::override(std::vector<T> &data,
u32 cnt) {
648 buf.copy_from_stdvec(data, cnt);
652inline void PatchDataField<T>::override(
const T val) {
658template<
class Lambdacd>
659inline std::vector<u32> PatchDataField<T>::get_elements_with_range(
660 Lambdacd &&cd_true, T vmin, T vmax) {
662 std::vector<u32> idxs;
694 auto acc = buf.copy_to_stdvec();
696 for (
u32 i = 0; i < get_val_cnt(); i++) {
697 if (cd_true(acc[i], vmin, vmax)) {
707template<
class Lambdacd>
708inline std::unique_ptr<sycl::buffer<u32>> PatchDataField<T>::get_elements_with_range_buf(
709 Lambdacd &&cd_true, T vmin, T vmax) {
710 std::vector<u32> idxs = get_elements_with_range(std::forward<Lambdacd>(cd_true), vmin, vmax);
718class PatchDataRangeCheckError :
public std::exception {
720 explicit PatchDataRangeCheckError(
const char *message) : msg_(message) {}
722 explicit PatchDataRangeCheckError(
const std::string &message) : msg_(message) {}
724 ~PatchDataRangeCheckError()
noexcept override =
default;
726 [[nodiscard]]
const char *what()
const noexcept override {
return msg_.c_str(); }
733template<
class Lambdacd>
734inline void PatchDataField<T>::check_err_range(
735 Lambdacd &&cd_true, T vmin, T vmax, std::string add_log) {
748 auto acc = buf.copy_to_stdvec();
751 for (
u32 i = 0; i < get_val_cnt(); i++) {
752 if (!cd_true(acc[i], vmin, vmax)) {
Header file describing a Node Instance.
sycl::queue & get_compute_queue(u32 id=0)
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
shamalgs::SerializeSize serialize_full_byte_size()
Give the size used by the serialization performed by serialize_full.
void permut_vars(const std::vector< u32 > &permut)
Permute the variables of the field according to the permutation `permut`.
shamrock::PatchDataFieldSpan< T, shamrock::dynamic_nvar > get_span_nvar_dynamic()
Returns a shamrock::PatchDataFieldSpan pointing to the current PatchDataField.
shamrock::PatchDataFieldSpan< T, nvar, is_pointer_access > get_span()
Returns a shamrock::PatchDataFieldSpan pointing to the current PatchDataField.
void index_remap(sham::DeviceBuffer< u32 > &index_map, u32 len)
This function remaps the PatchDataField such that val[id] = val[index_map[id]]. The index map describes: at ...
void remove_ids(const sham::DeviceBuffer< u32 > &indexes, u32 len)
remove the ids from the field
static PatchDataField deserialize_buf(shamalgs::SerializeHelper &serializer, std::string field_name, u32 nvar)
Deserialize a field; inverse of serialize_buf.
void append_subset_to(const std::vector< u32 > &idxs, PatchDataField &pfield)
Copy all objects in idxs to pfield.
shamalgs::SerializeSize serialize_buf_byte_size()
Record the size used by the serialization performed by serialize_buf.
u32 get_val_cnt() const
Get the number of values stored in the field.
static PatchDataField deserialize_full(shamalgs::SerializeHelper &serializer)
Deserialize a field; inverse of serialize_full.
sham::DeviceBuffer< u32 > get_ids_where(Lambdacd &&cd_true, Args... args) const
Same function as get_ids_set_where, but returns the ids in a sham::DeviceBuffer<u32>.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > get_ids_buf_where(Lambdacd &&cd_true, Args... args)
Same function as get_ids_set_where, but returns a tuple of an optional sycl::buffer<u32> holding the ids and a u32.
std::set< u32 > get_ids_set_where(Lambdacd &&cd_true, Args... args)
Get the set of ids of the objects for which the condition `cd_true` holds.
void index_remap_resize(sham::DeviceBuffer< u32 > &index_map, u32 len)
This function remaps the PatchDataField such that val[id] = val[index_map[id]]. The index map describes: at ...
void serialize_buf(shamalgs::SerializeHelper &serializer)
Minimal serialization, assuming the user knows the layout of the field.
std::vector< u32 > get_ids_vec_where(Lambdacd &&cd_true, Args &&...args)
Same function as get_ids_set_where, but returns the ids as a std::vector<u32>.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > get_elements_in_half_open(T vmin, T vmax)
Get the indices of the elements in half open interval.
void serialize_full(shamalgs::SerializeHelper &serializer)
serialize everything in the class
A buffer allocated in USM (Unified Shared Memory).
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
Class to manage a list of SYCL events.
Represents a span of data within a PatchDataField.
This header file contains utility functions related to exception handling in the code.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
sycl::buffer< T > vec_to_buf(const std::vector< T > &buf)
Convert a std::vector to a sycl::buffer.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
ExcptTypes make_except_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Create an exception with a message and a location.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
std::vector< std::string_view > args
Executable argument list (mapped from argv).
Utilities for safe type narrowing conversions.
main include file for memory algorithms
void err_ln(std::string module_name, Types... var2)
Prints a log message with multiple arguments followed by a newline.
This file contains the definition for the stacktrace related functionality.
shambase::details::BasicStackEntry StackEntry
Alias for shambase::details::BasicStackEntry.
A class that references multiple buffers or similar objects.