35 class PatchDataLayer {
// Per-field storage, one variant per layout entry; indexed by the u32 field
// index that get_field_variant()/get_field() bounds-check against fields.size().
41 std::vector<var_t> fields;
// Layout describing the fields. Held by shared_ptr: the copy constructor
// copies this pointer, so copied layers share the same layout object.
42 std::shared_ptr<PatchDataLayerLayout> pdl_ptr;
44 inline var_t &get_field_variant(
u32 idx) {
45 if (idx >= fields.size()) {
47 "the requested field index is out of bounds\n"
48 " current map is : \n"
49 + pdl().get_description_str()
52 + std::to_string(idx));
57 inline const var_t &get_field_variant(
u32 idx)
const {
58 if (idx >= fields.size()) {
60 "the requested field index is out of bounds\n"
61 " current map is : \n"
62 + pdl().get_description_str()
65 + std::to_string(idx));
// Alias exposing the variant type (var_t) stored in `fields`.
71 using field_variant_t = var_t;
76 inline std::shared_ptr<PatchDataLayerLayout> get_layout_ptr()
const {
return pdl_ptr; }
78 inline PatchDataLayer(
const std::shared_ptr<PatchDataLayerLayout> &pdl) : pdl_ptr(pdl) {
82 inline PatchDataLayer(
const PatchDataLayer &other) : pdl_ptr(other.get_layout_ptr()) {
86 for (
auto &field_var : other.fields) {
88 field_var.visit([&](
auto &field) {
90 typename std::remove_reference<
decltype(field)>::type::Field_type;
102 : fields(std::move(other.fields)), pdl_ptr(std::move(other.pdl_ptr)) {}
109 inline PatchDataLayer &
operator=(PatchDataLayer &&other)
noexcept {
110 fields = std::move(other.fields);
111 pdl_ptr = std::move(other.pdl_ptr);
118 u64 seed,
u32 obj_cnt,
const std::shared_ptr<PatchDataLayerLayout> &pdl);
120 template<
class Functor>
121 inline void for_each_field_any(Functor &&func) {
122 for (
auto &f : fields) {
123 f.visit([&](
auto &arg) {
129 template<
class Functor>
130 inline void for_each_field_any(Functor &&func)
const {
131 for (
auto &f : fields) {
132 f.visit([&](
const auto &arg) {
139 inline PatchDataLayer(
const std::shared_ptr<PatchDataLayerLayout> &pdl, Func &&fct_init)
147 inline PatchDataLayer duplicate() {
148 const PatchDataLayer ¤t = *
this;
149 return PatchDataLayer(current);
152 inline std::unique_ptr<PatchDataLayer> duplicate_to_ptr() {
153 const PatchDataLayer ¤t = *
this;
154 return std::make_unique<PatchDataLayer>(current);
// Declared here, defined out of line.
// NOTE(review): from the signature, extract_elements reads the index list
// `idxs` and writes into `out_pdat`; presumably the selected elements are
// taken out of *this into out_pdat — confirm against the definition.
165 void extract_elements(
const sham::DeviceBuffer<u32> &idxs, PatchDataLayer &out_pdat);
// sycl::buffer overload of keep_ids (a sham::DeviceBuffer overload is also
// declared in this class). NOTE(review): presumably keeps only the `len` ids
// listed in `index_map`; the buffer is taken by non-const reference — confirm
// whether it is modified.
167 void keep_ids(sycl::buffer<u32> &index_map,
u32 len);
// NOTE(review): presumably appends the contents of `pdat` (read-only, const&)
// to this layer — confirm against the definition.
169 void insert_elements(
const PatchDataLayer &pdat);
// Presumably size/capacity management; each call takes a u32 object count.
// NOTE(review): the exact semantics are not visible here — whether resize can
// shrink, whether reserve only affects capacity, and whether expand adds
// `obj_cnt` objects to the current count. Confirm against the definitions.
182 void resize(
u32 new_obj_cnt);
184 void reserve(
u32 new_obj_cnt);
186 void expand(
u32 obj_cnt);
// sham::DeviceBuffer overload of keep_ids (a sycl::buffer overload is also
// declared in this class). NOTE(review): presumably keeps only the `len` ids
// listed in `index_map`; the buffer is taken by non-const reference — confirm
// whether it is modified.
212 void keep_ids(sham::DeviceBuffer<u32> &index_map,
u32 len);
// NOTE(review): presumably removes the `len` entries listed in `indexes`
// (taken by const reference) — confirm against the definition.
215 void remove_ids(
const sham::DeviceBuffer<u32> &indexes,
u32 len);
// Templated on the box-corner type `Tvecbox`. Takes 8 destination layers (as
// reference_wrappers, so the layers themselves are written through) and 8
// min/max box corners; all three arrays are passed by value.
// NOTE(review): presumably distributes this layer's objects into pdats[i]
// according to box i = [min_box[i], max_box[i]]. Confirm the assignment rule
// (including boundary inclusivity) against the definition.
226 template<
class Tvecbox>
227 void split_patchdata(
228 std::array<std::reference_wrapper<PatchDataLayer>, 8> pdats,
229 std::array<Tvecbox, 8> min_box,
230 std::array<Tvecbox, 8> max_box);
// Three overloads that differ in how the index subset is supplied:
//  - a std::vector<u32> (count implied by its size),
//  - a sycl::buffer<u32> plus an explicit count `sz`,
//  - a sham::DeviceBuffer<u32> plus an explicit count `sz`.
// Only the DeviceBuffer overload is const-qualified.
// NOTE(review): presumably each appends the selected objects of *this to
// `pdat`. If so, the first two overloads could likely be const as well —
// confirm against the definitions.
232 void append_subset_to(
const std::vector<u32> &idxs, PatchDataLayer &pdat);
233 void append_subset_to(sycl::buffer<u32> &idxs_buf,
u32 sz, PatchDataLayer &pdat);
234 void append_subset_to(
235 const sham::DeviceBuffer<u32> &idxs_buf,
u32 sz, PatchDataLayer &pdat)
const;
237 inline u32 get_obj_cnt()
const {
239 bool is_empty = fields.empty();
242 return fields[0].visit_return([](
const auto &field) {
243 return field.get_obj_cnt();
248 "this PatchDataLayer does not contain any fields");
251 inline u64 memsize() {
254 for (
auto &field_var : fields) {
256 field_var.visit([&](
auto &field) {
257 sum += field.memsize();
264 inline bool is_empty() {
return get_obj_cnt() == 0; }
266 void synchronize_buf() {
267 for (
auto &field_var : fields) {
268 field_var.visit([&](
auto &field) {
269 field.synchronize_buf();
// Declared here, defined out of line. Takes another layer `pdat` by non-const
// reference and an object count.
// NOTE(review): which side is overwritten (this layer or `pdat`), and how
// `obj_cnt` is used, is not visible here — confirm against the definition.
274 void overwrite(PatchDataLayer &pdat,
u32 obj_cnt);
277 bool check_field_type(
u32 idx) {
278 var_t &tmp = get_field_variant(idx);
280 PatchDataField<T> *pval = std::get_if<PatchDataField<T>>(&tmp.value);
290 PatchDataField<T> &get_field(
u32 idx) {
292 var_t &tmp = get_field_variant(idx);
294 PatchDataField<T> *pval = std::get_if<PatchDataField<T>>(&tmp.value);
301 "the request id is not of correct type\n"
302 " current map is : \n"
303 + pdl().get_description_str()
306 + std::to_string(idx));
310 const PatchDataField<T> &get_field(
u32 idx)
const {
312 const var_t &tmp = get_field_variant(idx);
314 const PatchDataField<T> *pval = std::get_if<PatchDataField<T>>(&tmp.value);
321 "the request id is not of correct type\n"
322 " current map is : \n"
323 + pdl().get_description_str()
326 + std::to_string(idx));
330 PatchDataField<T> &get_field(
const std::string &field_name) {
331 return get_field<T>(pdl().get_field_idx<T>(field_name));
335 const PatchDataField<T> &get_field(
const std::string &field_name)
const {
336 return get_field<T>(pdl().get_field_idx<T>(field_name));
340 sham::DeviceBuffer<T> &get_field_buf_ref(
u32 idx) {
342 var_t &tmp = get_field_variant(idx);
344 PatchDataField<T> *pval = std::get_if<PatchDataField<T>>(&tmp.value);
347 return pval->get_buf();
351 "the request id is not of correct type\n"
352 " current map is : \n"
353 + pdl().get_description_str()
356 + std::to_string(idx));
365 template<
class T, u32 nvar>
367 return get_field<T>(idx).template get_span<nvar>();
379 return get_field<T>(idx).get_span_nvar_dynamic();
384 get_field_pointer_span(
u32 idx) {
385 return get_field<T>(idx).get_pointer_span();
393 u32 cnt = get_obj_cnt();
394 for (
auto &field_var : fields) {
395 field_var.visit([&](
auto &field) {
396 if (field.get_obj_cnt() != cnt) {
398 "mismatch in obj cnt");
410 template<
class T,
class Functor>
411 inline void for_each_field(Functor &&func) {
412 for (
auto &f : fields) {
421 inline friend bool operator==(PatchDataLayer &p1, PatchDataLayer &p2) {
424 if (p1.fields.size() != p2.fields.size()) {
428 for (
u32 idx = 0; idx < p1.fields.size(); idx++) {
430 bool ret = std::visit(
431 [&](
auto &pf1,
auto &pf2) ->
bool {
432 using t1 =
typename std::remove_reference<
decltype(pf1)>::type::Field_type;
433 using t2 =
typename std::remove_reference<
decltype(pf2)>::type::Field_type;
435 if constexpr (std::is_same<t1, t2>::value) {
436 return pf1.check_field_match(pf2);
441 p1.fields[idx].value,
442 p2.fields[idx].value);
444 check = check && ret;
// Serialization interface; definitions are out of line.
// NOTE(review): presumably serialize_buf_byte_size() reports the size that
// serialize_buf() writes, and deserialize_buf() rebuilds a layer from the
// serializer using the supplied layout `pdl`. Confirm against the definitions.
// Both instance methods are non-const, although serialization would not
// normally mutate *this — confirm whether they can be const.
450 void serialize_buf(shamalgs::SerializeHelper &serializer);
452 shamalgs::SerializeSize serialize_buf_byte_size();
454 static PatchDataLayer deserialize_buf(
455 shamalgs::SerializeHelper &serializer,
456 const std::shared_ptr<PatchDataLayerLayout> &pdl);
465 for (
auto &field_var : fields) {
466 field_var.visit([&](
auto &field) {
467 if (field.has_nan()) {
479 for (
auto &field_var : fields) {
480 field_var.visit([&](
auto &field) {
481 if (field.has_inf()) {
488 bool has_nan_or_inf() {
493 for (
auto &field_var : fields) {
494 field_var.visit([&](
auto &field) {
495 if (field.has_nan_or_inf()) {
513 u32 len = vec.size();
515 sycl::buffer<T> buf(vec.data(), len);
516 f.override(buf, len);
534 auto appender = [&](
auto &field) {
535 if (field.get_name() == key) {
537 shamlog_debug_ln(
"PyShamrockCTX",
"appending field", key);
539 if (!field.is_empty()) {
540 auto acc = field.get_buf().copy_to_stdvec();
541 u32 len = field.get_val_cnt();
543 for (
u32 i = 0; i < len; i++) {
544 vec.push_back(acc[i]);
550 for_each_field<T>([&](
auto &field) {