Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
PatchDataField.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
18#include "shambase/memory.hpp"
20#include "shambase/string.hpp"
29#include "shamalgs/random.hpp"
32#include "shambackends/vec.hpp"
37#include <algorithm>
38#include <memory>
39#include <numeric>
40#include <vector>
41
42template<class T>
44
45template<class T>
47
48 auto fast_extract_ptr = [](u32 idx, u32 length, auto cnt) {
49 T end_ = cnt[length - 1];
50 T extr = cnt[idx];
51
52 cnt[idx] = end_;
53
54 return extr;
55 };
56
57 auto sub_extract
58 = [fast_extract_ptr](u32 pidx, PatchDataField<T> &from, PatchDataField<T> &to) {
59 const u32 nvar = from.get_nvar();
60 const u32 idx_val = pidx * nvar;
61 const u32 idx_out_val = to.get_val_cnt();
62
63 u32 from_sz = from.get_val_cnt();
64
65 to.expand(1);
66
67 {
68
69 auto &buf_to = to.get_buf();
70 auto &buf_from = from.get_buf();
71
72 sham::EventList depends_list;
73 T *acc_to = buf_to.get_write_access(depends_list);
74 T *acc_from = buf_from.get_write_access(depends_list);
75
76 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
77
78 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
79 const u32 nvar_loc = nvar;
80
81 cgh.single_task<Kernel_Extract_element<T>>([=]() {
82 for (u32 i = nvar_loc - 1; i < nvar_loc; i--) {
83 acc_to[idx_out_val + i]
84 = (fast_extract_ptr(idx_val + i, from_sz, acc_from));
85 }
86 });
87 });
88
89 buf_to.complete_event_state(e);
90 buf_from.complete_event_state(e);
91 }
92
93 from.shrink(1);
94 };
95
96 sub_extract(pidx, *this, to);
97}
98
99template<class T>
102 if (&to == this) {
104 "source and destination for extract_elements cannot be the same");
105 }
106 StackEntry stack_loc{};
107 append_subset_to(idxs, idxs.get_size(), to);
108 remove_ids(idxs, idxs.get_size());
109}
110
111template<class T>
113 bool match = true;
114
115 match = match && (field_name == f2.field_name);
116 match = match && (nvar == f2.nvar);
117
118 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
119
120 // it will also check that the size is the same
121 match = match && shamalgs::primitives::equals(dev_sched, buf, f2.buf);
122
123 return match;
124}
125
126template<class T>
128 const sham::DeviceBuffer<u32> &idxs_buf, u32 sz, PatchDataField &pfield) const {
129
130 if (pfield.nvar != nvar)
132 "field must be similar for extraction");
133
134 if (static_cast<size_t>(sz) != idxs_buf.get_size())
136 "the size of the idxs buffer does not match the size of the subset");
137
138 if (sz > 0) {
139 const u32 start_enque = pfield.get_val_cnt();
140 const u32 nvar = get_nvar();
141 pfield.expand(sz);
142
143 shamalgs::primitives::append_subset_to(buf, idxs_buf, nvar, pfield.get_buf(), start_enque);
144 }
145}
146
147template<class T>
149 sycl::buffer<u32> &idxs_buf, u32 sz, PatchDataField &pfield) {
150
151 if (sz > 0) {
152 sham::DeviceBuffer<u32> buffer(sz, shamsys::instance::get_compute_scheduler_ptr());
153 buffer.copy_from_sycl_buffer(idxs_buf);
154 append_subset_to(buffer, sz, pfield);
155 }
156}
157
158template<class T>
159void PatchDataField<T>::append_subset_to(const std::vector<u32> &idxs, PatchDataField &pfield) {
160
161 u32 sz = shambase::narrow_or_throw<u32>(idxs.size());
162
163 if (sz > 0) {
164 sham::DeviceBuffer<u32> idxs_buf(sz, shamsys::instance::get_compute_scheduler_ptr());
165 idxs_buf.copy_from_stdvec(idxs);
166 append_subset_to(idxs_buf, sz, pfield);
167 }
168}
169
170template<class T>
172
173template<class T>
175 if (nvar != 1) {
177 }
178 u32 ins_pos = get_val_cnt();
179 expand(1);
180
181 auto sptr = shamsys::instance::get_compute_scheduler_ptr();
182 auto &q = sptr->get_queue();
183
184 sham::EventList depends_list;
185 T *acc = get_buf().get_write_access(depends_list);
186
187 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
188 auto id_ins = ins_pos;
189 auto val = v;
190
191 cgh.single_task<PdatField_insert_element<T>>([=]() {
192 acc[id_ins] = val;
193 });
194 });
195
196 get_buf().complete_event_state(e);
197}
198
199template<class T>
201
202template<class T>
204
205 if (get_obj_cnt() > 0) {
206
207 auto sptr = shamsys::instance::get_compute_scheduler_ptr();
208 auto &q = sptr->get_queue();
209
210 sham::EventList depends_list;
211 T *acc = get_buf().get_write_access(depends_list);
212
213 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
214 auto val = off;
215
216 cgh.parallel_for<PdatField_apply_offset<T>>(
217 sycl::range<1>{get_val_cnt()}, [=](sycl::id<1> idx) {
218 acc[idx] += val;
219 });
220 });
221 get_buf().complete_event_state(e);
222 }
223}
224
225template<class T>
227
228template<class T>
230 get_buf().append(f2.get_buf());
231}
232
233template<class T>
235
236 if (!buf.is_empty()) {
237
238 auto sched_ptr = shamsys::instance::get_compute_scheduler_ptr();
239
240 auto get_new_buf = [&]() {
241 if (nvar == 1) {
242 return shamalgs::algorithm::index_remap(sched_ptr, buf, index_map, len);
243 } else {
244 return shamalgs::algorithm::index_remap_nvar(sched_ptr, buf, index_map, len, nvar);
245 }
246 };
247
248 buf = get_new_buf();
249 }
250}
251
252template<class T>
254
255 if (len != get_obj_cnt()) {
257 "the match of the new index map does not match with the patchdatafield obj count: {} "
258 "!= {}",
259 len,
260 get_obj_cnt()));
261 }
262
263 index_remap_resize(index_map, len);
264}
265
266template<class T>
267void PatchDataField<T>::permut_vars(const std::vector<u32> &permut) {
268 if (permut.size() != get_nvar()) {
270 "the number of permut is not equal to the patchdatafield nvar: {} != {}",
271 permut.size(),
272 get_nvar()));
273 }
274
275 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
276 auto &q = dev_sched->get_queue();
277
278 sham::DeviceBuffer<u32> permut_buf(
279 permut.size(), shamsys::instance::get_compute_scheduler_ptr());
280 permut_buf.copy_from_stdvec(permut);
281
282 sham::DeviceBuffer<T> copy = buf.copy();
283
285 q,
286 sham::MultiRef{copy, permut_buf},
287 sham::MultiRef{buf},
288 get_val_cnt(),
289 [nvar = nvar](u32 i, const T *src, const u32 *permut, T *dst) {
290 u32 obj_id = i / nvar;
291 u32 var_id = i % nvar;
292
293 u32 new_var_id = permut[var_id];
294
295 dst[obj_id * nvar + new_var_id] = src[i];
296 });
297}
298
299template<class T>
301
302 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
303 auto &q = dev_sched->get_queue();
304
305 if (len > get_obj_cnt()) {
307 "the number of ids to remove is greater than the patchdatafield obj count: {} > {}",
308 len,
309 get_obj_cnt()));
310 }
311
312 if (len == 0) {
313 return;
314 }
315
316 auto nobj = get_obj_cnt();
317 auto remaining = nobj - len;
318
319 sham::DeviceBuffer<u32> keep_flag(get_obj_cnt(), dev_sched);
320 keep_flag.fill(1);
321
323 q,
324 sham::MultiRef{ids_to_rem},
325 sham::MultiRef{keep_flag},
326 len,
327 [](u32 i, const u32 *idx, u32 *idx_map) {
328 idx_map[idx[i]] = 0;
329 });
330
331 auto keep_ids = shamalgs::numeric::stream_compact(dev_sched, keep_flag, nobj);
332
333 if (keep_ids.get_size() != remaining) {
334
335 // post mortem analysis
336
337 std::vector<u32> ids_to_rem_vec = ids_to_rem.copy_to_stdvec_idx_range(0, len);
338
339 std::sort(ids_to_rem_vec.begin(), ids_to_rem_vec.end());
340
341 bool has_duplicates = false;
342
343 // Adjacent elements in ids_to_rem_vec should be different
344 has_duplicates = std::adjacent_find(ids_to_rem_vec.begin(), ids_to_rem_vec.end())
345 != ids_to_rem_vec.end();
346
347 std::vector<u32> keep_flags_vec = keep_flag.copy_to_stdvec();
348
349 // compute keep flags sum
350 u32 keep_flags_sum = std::accumulate(keep_flags_vec.begin(), keep_flags_vec.end(), u32(0));
351
352 std::string log = shambase::format(
353 "the number of remaining ids {} is different from the expected {}",
354 keep_ids.get_size(),
355 remaining);
356
357 log += "\n\nAdditional information:\n";
358 if (has_duplicates) {
359 log += " ids_to_rem has duplicates = true\n";
360 } else {
361 log += " ids_to_rem has duplicates = false\n";
362 }
363 log += shambase::format(" keep flags sum = {}\n", keep_flags_sum);
364
366 }
367
368 index_remap_resize(keep_ids, remaining);
369}
370
371template<class T>
373 u64 seed, u32 obj_cnt, std::string name, u32 nvar, T vmin, T vmax) {
374
375 std::vector<T> buf = shamalgs::primitives::mock_vector<T>(seed, obj_cnt * nvar, vmin, vmax);
376 PatchDataField<T> ret(name, nvar, obj_cnt);
377 ret.get_buf().copy_from_stdvec(buf);
378
379 return ret;
380}
381
382template<class T>
383PatchDataField<T> PatchDataField<T>::mock_field(u64 seed, u32 obj_cnt, std::string name, u32 nvar) {
386 seed, obj_cnt, name, nvar, Prop::get_min(), Prop::get_max());
387}
388
389template<class T>
391 StackEntry stack_loc{false};
392 u32 obj_cnt = get_obj_cnt();
393 serializer.write(obj_cnt);
394 shamlog_debug_sycl_ln("PatchDataField", "serialize patchdatafield len=", obj_cnt);
395 if (obj_cnt > 0) {
396 serializer.write_buf(buf, get_val_cnt());
397 }
398}
399
400template<class T>
402 shamalgs::SerializeHelper &serializer, std::string field_name, u32 nvar) {
403 StackEntry stack_loc{false};
404 u32 cnt;
405 serializer.load(cnt);
406 shamlog_debug_sycl_ln("PatchDataField", "deserialize patchdatafield len=", cnt);
407
408 if (cnt > 0) {
409 sham::DeviceBuffer<T> buf(cnt * nvar, serializer.get_device_scheduler());
410 serializer.load_buf(buf, cnt * nvar);
411 return PatchDataField<T>(std::move(buf), field_name, nvar);
412 } else {
413 return PatchDataField<T>(field_name, nvar, cnt);
414 }
415}
416
417template<class T>
419
421 return H::serialize_byte_size<u32>() + H::serialize_byte_size<T>(get_val_cnt());
422}
423
424template<class T>
426 StackEntry stack_loc{false};
427 serializer.write(nvar);
428 serializer.write(field_name);
429 serialize_buf(serializer);
430}
431
432template<class T>
435 return (H::serialize_byte_size<u32>()) + H::serialize_byte_size(field_name)
436 + serialize_buf_byte_size();
437}
438
439template<class T>
441 StackEntry stack_loc{false};
442 u32 nvar;
443 serializer.load(nvar);
444 std::string field_name;
445 serializer.load(field_name);
446
447 return deserialize_buf(serializer, field_name, nvar);
448}
449
450template<class T>
452 StackEntry stack_loc{};
453 if (is_empty()) {
455 }
456
457 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
458 return shamalgs::primitives::max(dev_sched, buf, 0, get_val_cnt());
459}
460
461template<class T>
463 StackEntry stack_loc{};
464 if (is_empty()) {
466 }
467
468 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
469 return shamalgs::primitives::min(dev_sched, buf, 0, get_val_cnt());
470}
471
472template<class T>
474 StackEntry stack_loc{};
475 if (is_empty()) {
477 }
478
479 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
480 return shamalgs::primitives::sum(dev_sched, buf, 0, get_val_cnt());
481}
482
483template<class T>
484shambase::VecComponent<T> PatchDataField<T>::compute_dot_sum() {
485 StackEntry stack_loc{};
486 if (is_empty()) {
488 }
489
490 return shamalgs::primitives::dot_sum(buf, 0, get_val_cnt());
491}
492
493template<class T>
495 StackEntry stack_loc{};
496
497 auto tmp = buf.copy_to_sycl_buffer();
498
499 return shamalgs::reduction::has_nan(shamsys::instance::get_compute_queue(), tmp, get_val_cnt());
500}
501template<class T>
503 StackEntry stack_loc{};
505 auto tmp = buf.copy_to_sycl_buffer();
506
507 return shamalgs::reduction::has_inf(shamsys::instance::get_compute_queue(), tmp, get_val_cnt());
508}
509template<class T>
511 StackEntry stack_loc{};
512
513 auto tmp = buf.copy_to_sycl_buffer();
514
515 return shamalgs::reduction::has_nan_or_inf(
516 shamsys::instance::get_compute_queue(), tmp, get_val_cnt());
517}
518
520// Define the patchdata field for all classes in XMAC_LIST_ENABLED_FIELD
522
523#ifndef DOXYGEN
524 #define X(a) template class PatchDataField<a>;
525XMAC_LIST_ENABLED_FIELD
526 #undef X
527#endif
530
532// data mocking for patchdata field
534
535const u32 obj_mock_cnt = 6000;
537#ifndef DOXYGEN
538template<>
539void PatchDataField<f32>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
540 resize(obj_cnt);
541
542 std::vector<f32> out(obj_cnt * nvar);
543 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
544
545 for (u32 i = 0; i < get_val_cnt(); i++) {
546 out[i] = f32(distf64(eng));
547 }
548
549 buf.copy_from_stdvec(out);
550}
551
552template<>
553void PatchDataField<f32_2>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
554 resize(obj_cnt);
555 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
556
557 std::vector<f32_2> out(obj_cnt * nvar);
558
559 for (u32 i = 0; i < get_val_cnt(); i++) {
560 out[i] = f32_2{distf64(eng), distf64(eng)};
562 buf.copy_from_stdvec(out);
563}
564
565template<>
566void PatchDataField<f32_3>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
567 resize(obj_cnt);
568 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
570 std::vector<f32_3> out(obj_cnt * nvar);
571
572 for (u32 i = 0; i < get_val_cnt(); i++) {
573 out[i] = f32_3{distf64(eng), distf64(eng), distf64(eng)};
574 }
575 buf.copy_from_stdvec(out);
577
578template<>
579void PatchDataField<f32_4>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
580 resize(obj_cnt);
581 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
582
583 std::vector<f32_4> out(obj_cnt * nvar);
584
585 for (u32 i = 0; i < get_val_cnt(); i++) {
586 out[i] = f32_4{distf64(eng), distf64(eng), distf64(eng), distf64(eng)};
587 }
588 buf.copy_from_stdvec(out);
589}
590
591template<>
592void PatchDataField<f32_8>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
593 resize(obj_cnt);
594 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
595
596 std::vector<f32_8> out(obj_cnt * nvar);
597
598 for (u32 i = 0; i < get_val_cnt(); i++) {
599 out[i] = f32_8{
600 distf64(eng),
601 distf64(eng),
602 distf64(eng),
603 distf64(eng),
604 distf64(eng),
605 distf64(eng),
606 distf64(eng),
607 distf64(eng)};
608 }
609 buf.copy_from_stdvec(out);
610}
611
612template<>
613void PatchDataField<f32_16>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
614 resize(obj_cnt);
615 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
616
617 std::vector<f32_16> out(obj_cnt * nvar);
618
619 for (u32 i = 0; i < get_val_cnt(); i++) {
620 out[i] = f32_16{
621 distf64(eng),
622 distf64(eng),
623 distf64(eng),
624 distf64(eng),
625 distf64(eng),
626 distf64(eng),
627 distf64(eng),
628 distf64(eng),
629 distf64(eng),
630 distf64(eng),
631 distf64(eng),
632 distf64(eng),
633 distf64(eng),
634 distf64(eng),
635 distf64(eng),
636 distf64(eng)};
637 }
638 buf.copy_from_stdvec(out);
639}
640
641template<>
642void PatchDataField<f64>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
643 resize(obj_cnt);
644 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
645
646 std::vector<f64> out(obj_cnt * nvar);
647
648 for (u32 i = 0; i < get_val_cnt(); i++) {
649 out[i] = f64(distf64(eng));
650 }
651 buf.copy_from_stdvec(out);
652}
653
654template<>
655void PatchDataField<f64_2>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
656 resize(obj_cnt);
657 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
658
659 std::vector<f64_2> out(obj_cnt * nvar);
660
661 for (u32 i = 0; i < get_val_cnt(); i++) {
662 out[i] = f64_2{distf64(eng), distf64(eng)};
663 }
664 buf.copy_from_stdvec(out);
665}
666
667template<>
668void PatchDataField<f64_3>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
669 resize(obj_cnt);
670 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
671
672 std::vector<f64_3> out(obj_cnt * nvar);
673
674 for (u32 i = 0; i < get_val_cnt(); i++) {
675 out[i] = f64_3{distf64(eng), distf64(eng), distf64(eng)};
676 }
677 buf.copy_from_stdvec(out);
678}
679
680template<>
681void PatchDataField<f64_4>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
682 resize(obj_cnt);
683 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
684
685 std::vector<f64_4> out(obj_cnt * nvar);
686
687 for (u32 i = 0; i < get_val_cnt(); i++) {
688 out[i] = f64_4{distf64(eng), distf64(eng), distf64(eng), distf64(eng)};
689 }
690 buf.copy_from_stdvec(out);
691}
692
693template<>
694void PatchDataField<f64_8>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
695 resize(obj_cnt);
696 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
697
698 std::vector<f64_8> out(obj_cnt * nvar);
699
700 for (u32 i = 0; i < get_val_cnt(); i++) {
701 out[i] = f64_8{
702 distf64(eng),
703 distf64(eng),
704 distf64(eng),
705 distf64(eng),
706 distf64(eng),
707 distf64(eng),
708 distf64(eng),
709 distf64(eng)};
710 }
711 buf.copy_from_stdvec(out);
712}
713
714template<>
715void PatchDataField<f64_16>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
716 resize(obj_cnt);
717 std::uniform_real_distribution<f64> distf64(1, obj_mock_cnt);
718
719 std::vector<f64_16> out(obj_cnt * nvar);
720
721 for (u32 i = 0; i < get_val_cnt(); i++) {
722 out[i] = f64_16{
723 distf64(eng),
724 distf64(eng),
725 distf64(eng),
726 distf64(eng),
727 distf64(eng),
728 distf64(eng),
729 distf64(eng),
730 distf64(eng),
731 distf64(eng),
732 distf64(eng),
733 distf64(eng),
734 distf64(eng),
735 distf64(eng),
736 distf64(eng),
737 distf64(eng),
738 distf64(eng)};
739 }
740 buf.copy_from_stdvec(out);
741}
742
743template<>
744void PatchDataField<u32>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
745 resize(obj_cnt);
746 std::uniform_int_distribution<u32> distu32(1, obj_mock_cnt);
747
748 std::vector<u32> out(obj_cnt * nvar);
749
750 for (u32 i = 0; i < get_val_cnt(); i++) {
751 out[i] = distu32(eng);
752 }
753 buf.copy_from_stdvec(out);
754}
755template<>
756void PatchDataField<u64>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
757 resize(obj_cnt);
758 std::uniform_int_distribution<u64> distu64(1, obj_mock_cnt);
759
760 std::vector<u64> out(obj_cnt * nvar);
761
762 for (u32 i = 0; i < get_val_cnt(); i++) {
763 out[i] = distu64(eng);
764 }
765 buf.copy_from_stdvec(out);
766}
767
768template<>
769void PatchDataField<u32_3>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
770 resize(obj_cnt);
771 std::uniform_int_distribution<u32> distu32(1, obj_mock_cnt);
772
773 std::vector<u32_3> out(obj_cnt * nvar);
774
775 for (u32 i = 0; i < get_val_cnt(); i++) {
776 out[i] = u32_3{distu32(eng), distu32(eng), distu32(eng)};
777 }
778 buf.copy_from_stdvec(out);
779}
780template<>
781void PatchDataField<u64_3>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
782 resize(obj_cnt);
783 std::uniform_int_distribution<u64> distu64(1, obj_mock_cnt);
784
785 std::vector<u64_3> out(obj_cnt * nvar);
786
787 for (u32 i = 0; i < get_val_cnt(); i++) {
788 out[i] = u64_3{distu64(eng), distu64(eng), distu64(eng)};
789 }
790 buf.copy_from_stdvec(out);
791}
792
793template<>
794void PatchDataField<i64_3>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
795 resize(obj_cnt);
796 std::uniform_int_distribution<i64> disti64(1, obj_mock_cnt);
797
798 std::vector<i64_3> out(obj_cnt * nvar);
799
800 for (u32 i = 0; i < get_val_cnt(); i++) {
801 out[i] = i64_3{disti64(eng), disti64(eng), disti64(eng)};
802 }
803 buf.copy_from_stdvec(out);
804}
805
806template<>
807void PatchDataField<i64>::gen_mock_data(u32 obj_cnt, std::mt19937 &eng) {
808 resize(obj_cnt);
809 std::uniform_int_distribution<i64> disti64(1, obj_mock_cnt);
810
811 std::vector<i64> out(obj_cnt * nvar);
812
813 for (u32 i = 0; i < get_val_cnt(); i++) {
814 out[i] = i64{disti64(eng)};
815 }
816 buf.copy_from_stdvec(out);
817}
818#endif
Header file describing a Node Instance.
double f64
Alias for double.
float f32
Alias for float.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int64_t i64
64 bit integer
shamalgs::SerializeSize serialize_full_byte_size()
give the size usage of serialize_full
void permut_vars(const std::vector< u32 > &permut)
Permute the variables of the field: variable v of each object is moved to position permut[v].
void index_remap(sham::DeviceBuffer< u32 > &index_map, u32 len)
this function remaps the patchdatafield like so val[id] = val[index_map[id]] index map describe : at ...
void remove_ids(const sham::DeviceBuffer< u32 > &indexes, u32 len)
remove the ids from the field
static PatchDataField deserialize_buf(shamalgs::SerializeHelper &serializer, std::string field_name, u32 nvar)
deserialize a field inverse of serialize_buf
void append_subset_to(const std::vector< u32 > &idxs, PatchDataField &pfield)
Copy all objects in idxs to pfield.
shamalgs::SerializeSize serialize_buf_byte_size()
record the size usage of the serialization using serialize_buf
u32 get_val_cnt() const
Get the number of values stored in the field.
static PatchDataField deserialize_full(shamalgs::SerializeHelper &serializer)
deserialize a field inverse of serialize_full
void index_remap_resize(sham::DeviceBuffer< u32 > &index_map, u32 len)
this function remaps the patchdatafield like so val[id] = val[index_map[id]] index map describe : at ...
void serialize_buf(shamalgs::SerializeHelper &serializer)
Minimal serialization, assuming the user knows the layout of the field.
void serialize_full(shamalgs::SerializeHelper &serializer)
serialize everything in the class
A buffer allocated in USM (Unified Shared Memory)
void copy_from_stdvec(const std::vector< T > &vec)
Copy the content of a std::vector into the buffer.
void fill(T value, std::array< size_t, 2 > idx_range)
Fill a subpart of the buffer with a given value.
std::vector< T > copy_to_stdvec() const
Copy the content of the buffer to a std::vector.
size_t get_size() const
Gets the number of elements in the buffer.
std::vector< T > copy_to_stdvec_idx_range(size_t begin, size_t end) const
Copies a specified range of elements from the buffer to a std::vector.
DeviceBuffer< T, target > copy() const
Copy the current buffer.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
Provides functions to compute the sum of dot products of elements in a device buffer with themselves.
Element-wise equality comparison algorithms for buffers.
This header file contains utility functions related to exception handling in the code.
Utility functions for generating random mock values.
Utility functions for generating random mock vectors.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
sycl::buffer< T > index_remap(sycl::queue &q, sycl::buffer< T > &source_buf, sycl::buffer< u32 > &index_map, u32 len)
Remap a buffer according to a given index map: result[i] = source_buf[index_map[i]].
Definition algorithm.cpp:32
sycl::buffer< T > index_remap_nvar(sycl::queue &q, sycl::buffer< T > &source_buf, sycl::buffer< u32 > &index_map, u32 len, u32 nvar)
remap a buffer (with multiple variable per index) according to a given index map result[i] = result[i...
Definition algorithm.cpp:51
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
Definition numeric.cpp:84
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
bool equals(sycl::queue &q, sycl::buffer< T > &buf1, sycl::buffer< T > &buf2, u32 cnt)
Compare elements between two sycl::buffers for equality.
Definition equals.hpp:77
shambase::VecComponent< T > dot_sum(sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of dot products of elements in a device buffer with themselves.
Definition dot_sum.cpp:26
T min(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Find the minimum element in a device buffer within a specified range.
T max(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Find the maximum element in a device buffer within a specified range.
void append_subset_to(const sham::DeviceBuffer< T > &buf, const sham::DeviceBuffer< u32 > &idxs_buf, u32 nvar, sham::DeviceBuffer< T > &buf_other, u32 start_enque)
Appends a subset of elements from one buffer to another.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
Utilities for safe type narrowing conversions.
A class that references multiple buffers or similar objects.