41 std::vector<var_t> fields;
42 std::shared_ptr<PatchDataLayerLayout> pdl_ptr;
50 inline std::shared_ptr<PatchDataLayerLayout> get_layout_ptr()
const {
return pdl_ptr; }
52 inline PatchDataLayer(
const std::shared_ptr<PatchDataLayerLayout> &pdl) : pdl_ptr(pdl) {
60 for (
auto &field_var : other.fields) {
62 field_var.visit([&](
auto &field) {
64 typename std::remove_reference<
decltype(field)>::type::Field_type;
76 : fields(std::move(other.fields)), pdl_ptr(std::move(other.pdl_ptr)) {}
84 fields = std::move(other.fields);
85 pdl_ptr = std::move(other.pdl_ptr);
92 u64 seed,
u32 obj_cnt,
const std::shared_ptr<PatchDataLayerLayout> &pdl);
94 template<
class Functor>
95 inline void for_each_field_any(Functor &&func) {
96 for (
auto &f : fields) {
97 f.visit([&](
auto &arg) {
103 template<
class Functor>
104 inline void for_each_field_any(Functor &&func)
const {
105 for (
auto &f : fields) {
106 f.visit([&](
const auto &arg) {
113 inline PatchDataLayer(
const std::shared_ptr<PatchDataLayerLayout> &pdl, Func &&fct_init)
121 inline PatchDataLayer duplicate() {
122 const PatchDataLayer ¤t = *
this;
123 return PatchDataLayer(current);
126 inline std::unique_ptr<PatchDataLayer> duplicate_to_ptr() {
127 const PatchDataLayer ¤t = *
this;
128 return std::make_unique<PatchDataLayer>(current);
141 void keep_ids(sycl::buffer<u32> &index_map,
u32 len);
143 void insert_elements(
const PatchDataLayer &pdat);
156 void resize(
u32 new_obj_cnt);
158 void reserve(
u32 new_obj_cnt);
160 void expand(
u32 obj_cnt);
200 template<
class Tvecbox>
201 void split_patchdata(
202 std::array<std::reference_wrapper<PatchDataLayer>, 8> pdats,
203 std::array<Tvecbox, 8> min_box,
204 std::array<Tvecbox, 8> max_box);
206 void append_subset_to(
const std::vector<u32> &idxs, PatchDataLayer &pdat);
207 void append_subset_to(sycl::buffer<u32> &idxs_buf,
u32 sz, PatchDataLayer &pdat);
208 void append_subset_to(
211 inline u32 get_obj_cnt()
const {
213 bool is_empty = fields.empty();
216 return fields[0].visit_return([](
const auto &field) {
217 return field.get_obj_cnt();
222 "this PatchDataLayer does not contain any fields");
225 inline u64 memsize() {
228 for (
auto &field_var : fields) {
230 field_var.visit([&](
auto &field) {
231 sum += field.memsize();
238 inline bool is_empty() {
return get_obj_cnt() == 0; }
240 void synchronize_buf() {
241 for (
auto &field_var : fields) {
242 field_var.visit([&](
auto &field) {
243 field.synchronize_buf();
248 void overwrite(PatchDataLayer &pdat,
u32 obj_cnt);
251 bool check_field_type(
u32 idx) {
252 var_t &tmp = fields.at(idx);
266 var_t &tmp = fields.at(idx);
275 "the request id is not of correct type\n"
276 " current map is : \n"
277 + pdl().get_description_str()
280 + std::to_string(idx));
286 const var_t &tmp = fields.at(idx);
295 "the request id is not of correct type\n"
296 " current map is : \n"
297 + pdl().get_description_str()
300 + std::to_string(idx));
305 return get_field<T>(pdl().get_field_idx<T>(field_name));
310 return get_field<T>(pdl().get_field_idx<T>(field_name));
316 var_t &tmp = fields.at(idx);
321 return pval->get_buf();
325 "the request id is not of correct type\n"
326 " current map is : \n"
327 + pdl().get_description_str()
330 + std::to_string(idx));
339 template<
class T, u32 nvar>
341 return get_field<T>(idx).template get_span<nvar>();
353 return get_field<T>(idx).get_span_nvar_dynamic();
358 get_field_pointer_span(
u32 idx) {
359 return get_field<T>(idx).get_pointer_span();
367 u32 cnt = get_obj_cnt();
368 for (
auto &field_var : fields) {
369 field_var.visit([&](
auto &field) {
370 if (field.get_obj_cnt() != cnt) {
372 "mismatch in obj cnt");
384 template<
class T,
class Functor>
385 inline void for_each_field(Functor &&func) {
386 for (
auto &f : fields) {
395 inline friend bool operator==(PatchDataLayer &p1, PatchDataLayer &p2) {
398 if (p1.fields.size() != p2.fields.size()) {
402 for (
u32 idx = 0; idx < p1.fields.size(); idx++) {
404 bool ret = std::visit(
405 [&](
auto &pf1,
auto &pf2) ->
bool {
406 using t1 =
typename std::remove_reference<
decltype(pf1)>::type::Field_type;
407 using t2 =
typename std::remove_reference<
decltype(pf2)>::type::Field_type;
409 if constexpr (std::is_same<t1, t2>::value) {
410 return pf1.check_field_match(pf2);
415 p1.fields[idx].value,
416 p2.fields[idx].value);
418 check = check && ret;
428 static PatchDataLayer deserialize_buf(
430 const std::shared_ptr<PatchDataLayerLayout> &pdl);
439 for (
auto &field_var : fields) {
440 field_var.visit([&](
auto &field) {
441 if (field.has_nan()) {
453 for (
auto &field_var : fields) {
454 field_var.visit([&](
auto &field) {
455 if (field.has_inf()) {
462 bool has_nan_or_inf() {
467 for (
auto &field_var : fields) {
468 field_var.visit([&](
auto &field) {
469 if (field.has_nan_or_inf()) {
487 u32 len = vec.size();
489 sycl::buffer<T> buf(vec.data(), len);
490 f.override(buf, len);
508 auto appender = [&](
auto &field) {
509 if (field.get_name() == key) {
511 shamlog_debug_ln(
"PyShamrockCTX",
"appending field", key);
513 if (!field.is_empty()) {
514 auto acc = field.get_buf().copy_to_stdvec();
515 u32 len = field.get_val_cnt();
517 for (
u32 i = 0; i < len; i++) {
518 vec.push_back(acc[i]);
524 for_each_field<T>([&](
auto &field) {