// Compile-time flag: true exactly when T is one of f32, f64, u32 or u64.
49 static constexpr bool isprimitive = std::is_same_v<T, f32> || std::is_same_v<T, f64>
50 || std::is_same_v<T, u32>
51 || std::is_same_v<T, u64>;
54 static constexpr bool is_in_type_list =
55 #define X(args) std::is_same<T, args>::value ||
56 XMAC_LIST_ENABLED_FIELD
false
62 "PatchDataField must be one of those types : "
64 #define X(args) #args " "
65 XMAC_LIST_ENABLED_FIELD
// Alias template parameterised on a boolean B and a type Tb (defaulting to void).
// NOTE(review): as written the alias names std::enable_if<B, Tb> itself and not its
// nested ::type, so — unlike std::enable_if_t — it stays well-formed when B is false.
// Confirm this is intended before relying on it to remove overloads via SFINAE.
70 template<
bool B,
class Tb =
void>
71 using enable_if_t =
typename std::enable_if<B, Tb>;
// enable_if_t instantiated with the isprimitive flag; Tb is left at its default (void).
73 using EnableIfPrimitive = enable_if_t<isprimitive>;
// enable_if_t instantiated with "T is in the enabled type list AND T is not primitive";
// Tb is left at its default (void).
75 using EnableIfVec = enable_if_t<is_in_type_list && (!isprimitive)>;
// Name of the field; handed out by const reference through get_name().
83 std::string field_name;
87 inline void check_nvar()
const {
99 : buf(std::move(other.buf)), field_name(std::move(other.field_name)),
100 nvar(std::move(other.nvar)) {}
103 buf = std::move(other.buf);
104 field_name = std::move(other.field_name);
105 nvar = std::move(other.nvar);
// Exposes the element type T of this field under the name Field_type.
112 using Field_type = T;
115 : field_name(std::move(name)), nvar(nvar),
116 buf(0, shamsys::instance::get_compute_scheduler_ptr()) {
121 : field_name(std::move(name)), nvar(nvar),
122 buf(obj_cnt * nvar, shamsys::instance::get_compute_scheduler_ptr()) {
127 : field_name(other.field_name), nvar(other.nvar), buf(other.buf.copy()) {
137 : field_name(name), nvar(nvar), buf(std::forward<sycl::buffer<T>>(moved_buf),
139 shamsys::instance::get_compute_scheduler_ptr()) {
157 ret.field_name = new_name;
161 inline std::unique_ptr<PatchDataField> duplicate_to_ptr()
const {
163 return std::make_unique<PatchDataField>(current);
/**
 * @brief Tell whether the field currently holds no object.
 * @return true when get_obj_cnt() reports zero, false otherwise.
 */
169 [[nodiscard]] inline bool is_empty() const {
    const u32 stored_objects = get_obj_cnt();
    return stored_objects == 0;
}
/**
 * @brief Memory usage of the field as reported by the underlying buffer.
 * @return the result of buf.get_mem_usage(), converted to u64.
 */
171 [[nodiscard]] inline u64 memsize() const {
    const u64 usage = buf.get_mem_usage();
    return usage;
}
/**
 * @brief Read-only access to the nvar member.
 * @return a const reference to nvar (not a copy).
 */
173 [[nodiscard]] inline const u32 &get_nvar() const { return nvar; }
175 [[nodiscard]]
inline u32 get_obj_cnt()
const {
176 size_t sz = buf.get_size();
177 if (sz % nvar != 0) {
179 "the size of the buffer ({}) is not a multiple of the number of variables ({})",
/**
 * @brief Read-only access to the field name.
 * @return a const reference to field_name (not a copy).
 */
199 [[nodiscard]] inline const std::string &get_name() const { return field_name; }
// Size-management entry points. Only the declarations are visible here, so the notes
// below are hedged wherever they rely on names or on detached fragments of this listing.

// presumably sets the number of stored objects to new_obj_cnt — body not visible, confirm
202 void resize(
u32 new_obj_cnt);
// presumably pre-allocates storage; a detached fragment (`buf.reserve(add_cnt)`) looks
// like its body — confirm whether the argument is a total or an additional count
204 void reserve(
u32 new_obj_cnt);
// NOTE(review): a detached fragment `resize(get_obj_cnt() + obj_to_add)` appears to be
// this function's body, i.e. grow the field by obj_to_add objects — confirm
206 void expand(
u32 obj_to_add);
// NOTE(review): detached fragments show a guard on `obj_to_rem > get_obj_cnt()` with the
// message "impossible to remove more object than there is in the patchdata field",
// followed by `resize(get_obj_cnt() - obj_to_rem)`; they appear to belong here — confirm
208 void shrink(
u32 obj_to_rem);
// Takes one value of the element type T by value.
// presumably appends v to the field — body not visible here, confirm
210 void insert_element(T v);
// Takes one value of the element type T by value.
// presumably adds off to the stored values — body not visible here, confirm
212 void apply_offset(T off);
// Three overloads of a member function named `override` (a contextual keyword, so it is
// a legal identifier here). Bodies are not visible at these declarations.

// Source is a sycl::buffer<T> plus an element count.
// NOTE(review): a detached fragment `buf.copy_from_sycl_buffer(data, cnt)` looks like its body — confirm
220 void override(sycl::buffer<T> &data,
u32 cnt);
// Source is a std::vector<T> plus an element count.
// NOTE(review): a detached fragment `buf.copy_from_stdvec(data, cnt)` looks like its body — confirm
222 void override(std::vector<T> &data,
u32 cnt);
// Source is a single value; presumably every stored value is set to val — confirm
224 void override(
const T val);
/** @brief Forward a synchronize() call to the underlying buffer. */
226 inline void synchronize_buf() {
    buf.synchronize();
}
228 inline std::vector<T> copy_to_stdvec() {
229 auto tmp = buf.copy_to_stdvec();
244 template<u32 nvar,
bool is_pointer_access = shamrock::access_t_span>
258 return get_span<shamrock::dynamic_nvar, shamrock::access_t_span>();
264 return get_span<shamrock::dynamic_nvar, shamrock::access_t_pointer>();
280 template<
class Lambdacd,
class... Args>
283 std::set<u32> idx_cd{};
284 if (get_obj_cnt() > 0) {
285 auto acc = get_buf().copy_to_stdvec();
287 for (
u32 i = 0; i < get_obj_cnt(); i++) {
288 if (cd_true(acc, i * nvar, args...)) {
306 template<
class Lambdacd,
class... Args>
309 std::vector<u32> idx_cd{};
310 if (get_obj_cnt() > 0) {
311 auto acc = buf.copy_to_stdvec();
313 for (
u32 i = 0; i < get_obj_cnt(); i++) {
314 if (std::forward<Lambdacd>(cd_true)(acc, i * nvar, std::forward<Args>(args)...)) {
332 template<
class Lambdacd,
class... Args>
334 Lambdacd &&cd_true, Args... args) {
337 if (get_obj_cnt() > 0) {
340 sycl::buffer<u32> mask(get_obj_cnt());
343 const T *acc = buf.get_read_access(depends_list);
347 auto e = q.
submit(depends_list, [&, args...](sycl::handler &cgh) {
348 sycl::accessor acc_mask{mask, cgh, sycl::write_only, sycl::no_init};
349 u32 nvar_field = nvar;
351 shambase::parallel_for(
352 cgh, get_obj_cnt(),
"PatchdataField::get_ids_buf_where", [=](
u32 id) {
353 acc_mask[id] = cd_true(acc,
id * nvar_field, args...);
357 buf.complete_event_state(e);
360 shamsys::instance::get_compute_queue(), mask, get_obj_cnt());
362 return {std::nullopt, 0};
366 template<
class Lambdacd,
class... Args>
371 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
374 auto obj_cnt = get_obj_cnt();
384 [=, nvar_field = nvar](
u32 id,
const T *__restrict acc,
u32 *__restrict acc_mask) {
385 acc_mask[id] = cd_true(acc,
id * nvar_field,
args...);
404 template<
class Lambdacd,
class... Args>
408 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
411 auto obj_cnt = get_obj_cnt();
421 [=, nvar_field = nvar](
u32 id,
const T *__restrict acc,
u32 *__restrict acc_mask) {
422 acc_mask[id] = cd_true(acc,
id * nvar_field, args...);
// Deprecated: the attribute message points callers to the
// PatchDataField::get_ids_..._where family instead.
// Takes a predicate cd_true and two bounds of type T; returns a std::vector<u32>.
// NOTE(review): a detached fragment of this listing evaluates cd_true(acc[i], vmin, vmax)
// for every i below get_val_cnt(); presumably the returned vector holds the indices for
// which the predicate is true — confirm against the full definition.
431 template<
class Lambdacd>
432 [[deprecated(
"please use one of the PatchDataField::get_ids_..._where functions instead")]]
433 std::vector<u32> get_elements_with_range(Lambdacd &&cd_true, T vmin, T vmax);
443 template<
class LambdaCd>
444 [[deprecated(
"please use one of the PatchDataField::get_ids_..._where functions instead")]]
// Deprecated: the attribute message points callers to the
// PatchDataField::get_ids_..._where family instead.
// Same parameters as get_elements_with_range, but the result is handed back as a
// std::unique_ptr<sycl::buffer<u32>>.
// NOTE(review): a detached fragment shows a call to
// get_elements_with_range(std::forward<Lambdacd>(cd_true), vmin, vmax) that appears to be
// the start of this function's body — confirm.
447 template<
class Lambdacd>
448 [[deprecated(
"please use one of the PatchDataField::get_ids_..._where functions instead")]]
449 std::unique_ptr<sycl::buffer<u32>> get_elements_with_range_buf(
450 Lambdacd &&cd_true, T vmin, T vmax);
// Takes a predicate cd_true, two bounds of type T, and an optional extra log string
// (add_log, passed by value, defaulting to the empty string).
// NOTE(review): detached fragments of this listing test `!cd_true(acc[i], vmin, vmax)`
// over get_val_cnt() values and log add_log after "additional infos :"; presumably this
// reports values for which the predicate fails — confirm against the full definition.
456 template<
class Lambdacd>
457 void check_err_range(Lambdacd &&cd_true, T vmin, T vmax, std::string add_log =
"");
464 inline void field_raz() {
465 shamlog_debug_ln(
"PatchDataField",
"raz : ", field_name);
// Two overloads that differ only in how the indices are supplied; both take the
// destination field pfield by non-const reference. Bodies are not visible here.
// presumably the objects selected by the indices are appended to pfield — confirm

// Indices given as a std::vector<u32>.
475 void append_subset_to(
const std::vector<u32> &idxs,
PatchDataField &pfield);
// Indices given as a sycl::buffer<u32> together with a count sz.
476 void append_subset_to(sycl::buffer<u32> &idxs_buf,
u32 sz,
PatchDataField &pfield);
477 void append_subset_to(
480 inline PatchDataField make_new_from_subset(sycl::buffer<u32> &idxs_buf,
u32 sz) {
482 append_subset_to(idxs_buf, sz, pfield);
// Takes an object count and a std::mt19937 engine by non-const reference (so the
// engine state can be advanced by the call).
// presumably fills the field with obj_cnt randomly generated objects — body not visible, confirm
492 void gen_mock_data(
u32 obj_cnt, std::mt19937 &eng);
// Takes a read-only vector of u32 indices.
// presumably reorders the stored variables according to permut — body not visible, confirm
507 void permut_vars(
const std::vector<u32> &permut);
// Three const member functions, each returning a single value of the element type T.
// presumably the maximum, minimum and sum over the stored values respectively —
// bodies not visible here, confirm (in particular the behaviour on an empty field).
578 T compute_max()
const;
579 T compute_min()
const;
580 T compute_sum()
const;
// Returns a shambase::VecComponent<T>; note that, unlike compute_max/min/sum, it is not
// declared const. presumably the sum of the dot product of each value with itself — confirm
582 shambase::VecComponent<T> compute_dot_sum();
// Returns a bool and is not declared const.
// presumably true when any stored value is NaN or infinite — body not visible, confirm
586 bool has_nan_or_inf();
594 u64 seed,
u32 obj_cnt, std::string name,
u32 nvar, T vmin, T vmax);
607 buf.reserve(add_cnt);
612 resize(get_obj_cnt() + obj_to_add);
618 if (obj_to_rem > get_obj_cnt()) {
621 "impossible to remove more object than there is in the patchdata field");
624 resize(get_obj_cnt() - obj_to_rem);
630 buf.copy_from(f2.buf, obj_cnt * f2.nvar);
636 buf.copy_from(f2, len);
642 buf.copy_from_sycl_buffer(data, cnt);
648 buf.copy_from_stdvec(data, cnt);
658template<
class Lambdacd>
660 Lambdacd &&cd_true, T vmin, T vmax) {
662 std::vector<u32> idxs;
694 auto acc = buf.copy_to_stdvec();
696 for (
u32 i = 0; i < get_val_cnt(); i++) {
697 if (cd_true(acc[i], vmin, vmax)) {
707template<
class Lambdacd>
709 Lambdacd &&cd_true, T vmin, T vmax) {
710 std::vector<u32> idxs = get_elements_with_range(std::forward<Lambdacd>(cd_true), vmin, vmax);
/**
 * @brief Give access to the stored message as a C string.
 *
 * Declared `override`, so it replaces a virtual what() of a base class that is not
 * visible in this excerpt.
 *
 * @return the pointer obtained from msg_.c_str().
 */
726 [[nodiscard]] const char *what() const noexcept override {
    const char *text = msg_.c_str();
    return text;
}
733template<
class Lambdacd>
735 Lambdacd &&cd_true, T vmin, T vmax, std::string add_log) {
748 auto acc = buf.copy_to_stdvec();
751 for (
u32 i = 0; i < get_val_cnt(); i++) {
752 if (!cd_true(acc[i], vmin, vmax)) {
767 logger::err_ln(
"PatchDataField",
"...");
775 logger::err_ln(
"PatchDataField",
"additional infos :", add_log);
Header file describing a Node Instance.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
shamrock::PatchDataFieldSpan< T, shamrock::dynamic_nvar > get_span_nvar_dynamic()
Returns a shamrock::PatchDataFieldSpan pointing to the current PatchDataField.
shamrock::PatchDataFieldSpan< T, nvar, is_pointer_access > get_span()
Returns a shamrock::PatchDataFieldSpan pointing to the current PatchDataField.
u32 get_val_cnt() const
Get the number of values stored in the field.
sham::DeviceBuffer< u32 > get_ids_where(Lambdacd &&cd_true, Args... args) const
Same function as get_ids_set_where, but the ids are returned in a sham::DeviceBuffer< u32 >.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > get_ids_buf_where(Lambdacd &&cd_true, Args... args)
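Same function as get_ids_set_where, but the ids are returned as an optional sycl::buffer< u32 > together with their count.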
std::set< u32 > get_ids_set_where(Lambdacd &&cd_true, Args... args)
Get the set of ids of the objects for which the condition cd_true holds.
std::vector< u32 > get_ids_vec_where(Lambdacd &&cd_true, Args &&...args)
Same function as get_ids_set_where, but the ids are returned in a std::vector< u32 >.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > get_elements_in_half_open(T vmin, T vmax)
Get the indices of the elements in half open interval.
A buffer allocated in USM (Unified Shared Memory)
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Represents a span of data within a PatchDataField.
This header file contains utility functions related to exception handling in the code.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
sycl::buffer< T > index_remap(sycl::queue &q, sycl::buffer< T > &source_buf, sycl::buffer< u32 > &index_map, u32 len)
Remap a buffer according to a given index map: result[i] = source_buf[index_map[i]].
sycl::buffer< T > vec_to_buf(const std::vector< T > &buf)
Convert a std::vector to a sycl::buffer
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
void append_subset_to(const sham::DeviceBuffer< T > &buf, const sham::DeviceBuffer< u32 > &idxs_buf, u32 nvar, sham::DeviceBuffer< T > &buf_other, u32 start_enque)
Appends a subset of elements from one buffer to another.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the pointer does not hold an object.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
std::vector< std::string_view > args
Executable argument list (mapped from argv)
Utilities for safe type narrowing conversions.
main include file for memory algorithms
This file contains the definition for the stacktrace related functionality.
A class that references multiple buffers or similar objects.