26namespace shamalgs::numeric::details {
// NOTE(review): garbled extraction — the integers fused at the start of each
// line are the original file's own line numbers, and the gaps in that
// numbering show that several original lines are missing from this chunk
// (declarations, closing braces, the final return). Code tokens are left
// byte-identical; only comments were added.
//
// Exclusive prefix sum of the first `len` elements of `buf1`, computed in two
// device passes:
//   pass 1: shift input right by one (out[i] = in[i-1], out[0] = 0);
//   pass 2: per-work-group inclusive scan, then each group adds the running
//           total of all preceding groups, serialized through a single global
//           atomic accumulator (`global_summation`) ordered by a group counter.
28 template<
class T, u32 group_size>
31 template<
class T, u32 group_size>
32 sycl::buffer<T> exclusive_sum_atomic2pass(sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
// NOTE(review): the declaration of `group_cnt` is missing from this chunk —
// presumably ceil(len / group_size); confirm against the original file.
35 u32 corrected_len = group_cnt * group_size;
37 sycl::buffer<T> ret_buf(len);
// Pass 1: write the exclusive-shifted input into ret_buf.
39 q.submit([&, len](sycl::handler &cgh) {
40 sycl::accessor acc_in{buf1, cgh, sycl::read_only};
41 sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};
43 cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
44 u32 thid =
id.get_linear_id();
// First element of an exclusive sum is the identity (0).
45 acc_out[id] = (thid > 0) ? acc_in[thid - 1] : 0;
// Dynamic group-id generation: logical group order is assigned at runtime,
// decoupled from hardware scheduling order (needed so the serialized
// accumulation below cannot deadlock on out-of-order group launch).
52 atomic::DynamicIdGenerator<i32, group_size> id_gen(q);
54 atomic::DeviceCounter<i32> device_count(q);
55 atomic::DeviceCounter<u32> global_summation(q);
57 q.submit([&, group_cnt, len](sycl::handler &cgh) {
58 sycl::accessor value_buffer{ret_buf, cgh, sycl::read_write};
60 auto dyn_id = id_gen.get_access(cgh);
61 auto device_counter = device_count.get_access(cgh);
62 auto global_sum = global_summation.get_access(cgh);
// One scalar of local memory each: the group total and the prefix of all
// preceding groups, broadcast to the whole group through local memory.
64 sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
65 sycl::local_accessor<T, 1> local_sum{1, cgh};
67 cgh.parallel_for<KernelExclusiveSumAtomicSync<T, group_size>>(
68 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
69 atomic::DynamicId<i32> group_id = dyn_id.compute_id(
id);
// NOTE(review): declaration of `local_val` is missing here — presumably
// `T local_val = 0;` so padding work-items contribute the identity.
73 if (group_id.dyn_global_id < len) {
74 local_val = value_buffer[group_id.dyn_global_id];
81 T local_scan = sycl::inclusive_scan_over_group(
82 id.get_group(), local_val, sycl::plus<T>{});
// After the scan the last work-item of the group holds the group total.
84 if (
id.get_local_id(0) == group_size - 1) {
85 local_scan_buf[0] = local_scan;
89 id.barrier(sycl::access::fence_space::local_space);
// One thread per group publishes the group total into the global running
// sum, strictly in logical group order.
92 if (group_id.is_main_thread) {
95 sycl::atomic_ref atomic_counter
96 = device_counter.attach_atomic<sycl::memory_order_acq_rel>();
97 sycl::atomic_ref atomic_sum
98 = global_sum.attach_atomic<sycl::memory_order_relaxed>();
101 T group_sum = local_scan_buf[0];
105 if (group_id.dyn_group_id == 0) {
108 atomic_sum += group_sum;
// Spin until every preceding group has contributed its total.
// NOTE(review): cross-group spin-wait assumes forward-progress guarantees
// between work-groups — not portable to all devices; confirm targets.
113 while (atomic_counter.load() != group_id.dyn_group_id) {
// fetch_add returns the pre-add value, i.e. the exclusive prefix of this
// group over all preceding groups.
116 T exclusive_group_prefix_sum = atomic_sum.fetch_add(group_sum);
119 local_sum[0] = exclusive_group_prefix_sum;
124 id.barrier(sycl::access::fence_space::local_space);
// Final value = within-group inclusive scan + sum of all preceding groups
// (inclusive scan of the shifted input == exclusive scan of the original).
127 if (group_id.dyn_global_id < len) {
128 value_buffer[group_id.dyn_global_id] = local_scan + local_sum[0];
// NOTE(review): garbled extraction — fused leading integers are the original
// file's line numbers and gaps in them mean original lines are missing here
// (e.g. declarations of `group_cnt` and `local_val`, closing braces, the
// return). Code tokens are byte-identical; only comments were added.
//
// Variant of exclusive_sum_atomic2pass: instead of one global atomic
// accumulator, each group stores its inclusive prefix into a per-group
// `aggregates` buffer and group g reads the prefix of group g-1; groups are
// still serialized in logical order by a global counter.
137 template<
class T, u32 group_size>
140 template<
class T, u32 group_size>
141 sycl::buffer<T> exclusive_sum_atomic2pass_v2(sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
// NOTE(review): `group_cnt` declaration missing from this chunk.
144 u32 corrected_len = group_cnt * group_size;
146 sycl::buffer<T> ret_buf(len);
// Pass 1: exclusive shift (out[0] = 0, out[i] = in[i-1]).
148 q.submit([&, len](sycl::handler &cgh) {
149 sycl::accessor acc_in{buf1, cgh, sycl::read_only};
150 sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};
152 cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
153 u32 thid =
id.get_linear_id();
154 acc_out[id] = (thid > 0) ? acc_in[thid - 1] : 0;
// Per-group inclusive prefixes, one slot per work-group.
162 sycl::buffer<T> aggregates(group_cnt);
164 atomic::DynamicIdGenerator<i32, group_size> id_gen(q);
166 atomic::DeviceCounter<i32> device_count(q);
168 q.submit([&, group_cnt, len](sycl::handler &cgh) {
169 sycl::accessor value_buffer{ret_buf, cgh, sycl::read_write};
171 auto dyn_id = id_gen.get_access(cgh);
172 auto device_counter = device_count.get_access(cgh);
174 sycl::accessor acc_gsum{aggregates, cgh, sycl::read_write};
176 sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
177 sycl::local_accessor<T, 1> local_sum{1, cgh};
179 cgh.parallel_for<KernelExclusiveSumAtomicSync_v2<T, group_size>>(
180 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
181 atomic::DynamicId<i32> group_id = dyn_id.compute_id(
id);
// NOTE(review): `local_val` declaration missing — presumably `T local_val = 0;`.
185 if (group_id.dyn_global_id < len) {
186 local_val = value_buffer[group_id.dyn_global_id];
193 T local_scan = sycl::inclusive_scan_over_group(
194 id.get_group(), local_val, sycl::plus<T>{});
// Last lane holds the group total after the inclusive scan.
196 if (
id.get_local_id(0) == group_size - 1) {
197 local_scan_buf[0] = local_scan;
201 id.barrier(sycl::access::fence_space::local_space);
204 if (group_id.is_main_thread) {
207 sycl::atomic_ref atomic_counter
208 = device_counter.attach_atomic<sycl::memory_order_acq_rel>();
211 T group_sum = local_scan_buf[0];
// NOTE(review): the value-type template argument line of this atomic_ref
// alias is missing from the chunk (numbering skips 216).
215 using atomic_ref_T = sycl::atomic_ref<
217 sycl::memory_order_relaxed,
218 sycl::memory_scope_device,
219 sycl::access::address_space::global_space>;
// Group 0 seeds the aggregate chain with its own total.
221 if (group_id.dyn_group_id == 0) {
224 atomic_ref_T(acc_gsum[0]).store(group_sum);
// Spin until the previous group's inclusive prefix has been published.
// NOTE(review): cross-group spin — same forward-progress caveat as v1.
230 while (atomic_counter.load() != group_id.dyn_group_id) {
233 T exclusive_group_prefix_sum
234 = atomic_ref_T(acc_gsum[group_id.dyn_group_id - 1]).load();
// Publish this group's inclusive prefix for the next group to consume.
236 atomic_ref_T(acc_gsum[group_id.dyn_group_id])
237 .store(exclusive_group_prefix_sum + group_sum);
240 local_sum[0] = exclusive_group_prefix_sum;
245 id.barrier(sycl::access::fence_space::local_space);
248 if (group_id.dyn_global_id < len) {
249 value_buffer[group_id.dyn_global_id] = local_scan + local_sum[0];
// NOTE(review): garbled extraction — fused leading integers are the original
// file's line numbers; gaps show missing lines (declarations of `group_cnt`,
// `local_val`, `accum`, spin-loop bodies, closing braces, the return). Code
// tokens are byte-identical; only comments were added.
//
// Exclusive prefix sum in the style of a decoupled look-back scan (cf.
// Merrill & Garland): each tile publishes state X (invalid) -> A (aggregate
// available) -> P (inclusive prefix available) and the main thread of tile g
// walks backwards from tile g-1 accumulating aggregates/prefixes, so groups
// are NOT fully serialized as in the 2pass variants.
258 template<
class T, u32 group_size>
261 template<
class T, u32 group_size>
262 sycl::buffer<T> exclusive_sum_atomic_decoupled(sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
// NOTE(review): `group_cnt` declaration missing from this chunk.
265 u32 corrected_len = group_cnt * group_size;
268 sycl::buffer<T> ret_buf(len);
// Pass 1: exclusive shift of the input.
270 q.submit([&, len](sycl::handler &cgh) {
271 sycl::accessor acc_in{buf1, cgh, sycl::read_only};
272 sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};
274 cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
275 u32 thid =
id.get_linear_id();
276 acc_out[id] = (thid > 0) ? acc_in[thid - 1] : 0;
// Per-tile descriptor: state flag, group aggregate, group inclusive prefix.
// NOTE(review): tile_state presumably needs initialization to STATE_X — the
// initialization lines are missing from this chunk; confirm in the original.
284 sycl::buffer<i32> tile_state(group_cnt);
285 sycl::buffer<T> tile_aggregates(group_cnt);
286 sycl::buffer<T> tile_incl_prefix(group_cnt);
// X = nothing published yet, A = aggregate published, P = prefix published.
288 constexpr i32 STATE_X = 0;
289 constexpr i32 STATE_A = 1;
290 constexpr i32 STATE_P = 2;
296 atomic::DynamicIdGenerator<i32, group_size> id_gen(q);
298 q.submit([&, group_cnt, len](sycl::handler &cgh) {
299 sycl::accessor acc_value{ret_buf, cgh, sycl::read_write};
300 sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};
301 sycl::accessor acc_tile_aggregates{tile_aggregates, cgh, sycl::read_write};
302 sycl::accessor acc_tile_incl_prefix{tile_incl_prefix, cgh, sycl::read_write};
304 auto dyn_id = id_gen.get_access(cgh);
306 sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
307 sycl::local_accessor<T, 1> local_sum{1, cgh};
// NOTE(review): value-type template argument lines of both aliases are
// missing from the chunk (numbering skips 310 and 316).
309 using atomic_ref_state = sycl::atomic_ref<
311 sycl::memory_order_relaxed,
312 sycl::memory_scope_device,
313 sycl::access::address_space::global_space>;
315 using atomic_ref_T = sycl::atomic_ref<
317 sycl::memory_order_relaxed,
318 sycl::memory_scope_device,
319 sycl::access::address_space::global_space>;
321 cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled<T, group_size>>(
322 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
323 atomic::DynamicId<i32> group_id = dyn_id.compute_id(
id);
// NOTE(review): `local_val` declaration missing — presumably `T local_val = 0;`.
327 if (group_id.dyn_global_id < len) {
328 local_val = acc_value[group_id.dyn_global_id];
335 T local_scan = sycl::inclusive_scan_over_group(
336 id.get_group(), local_val, sycl::plus<T>{});
338 if (
id.get_local_id(0) == group_size - 1) {
339 local_scan_buf[0] = local_scan;
343 id.barrier(sycl::access::fence_space::local_space);
// Look-back phase, executed by the main thread of each group only.
346 if (group_id.is_main_thread) {
349 T local_group_sum = local_scan_buf[0];
354 if (group_id.dyn_group_id != 0) {
// Publish our aggregate first so later tiles can make progress.
356 atomic_ref_T(acc_tile_aggregates[group_id.dyn_group_id])
357 .store(local_group_sum);
358 atomic_ref_state(acc_tile_state[group_id.dyn_group_id]).store(STATE_A);
360 u32 tile_ptr = group_id.dyn_group_id - 1;
// Walk predecessor tiles. NOTE(review): `accum` declaration and the
// surrounding loop/decrement logic are missing from this chunk; from the
// visible branches: A adds that tile's aggregate (keep walking back),
// P adds its inclusive prefix (walk terminates), X presumably spins.
363 i32 tstate = atomic_ref_state(acc_tile_state[tile_ptr]).load();
365 if (tstate == STATE_X) {
369 if (tstate == STATE_A) {
370 accum += atomic_ref_T(acc_tile_aggregates[tile_ptr]).load();
373 if (tstate == STATE_P) {
374 accum += atomic_ref_T(acc_tile_incl_prefix[tile_ptr]).load();
// Publish our inclusive prefix, then flip our state to P.
382 atomic_ref_T(acc_tile_incl_prefix[group_id.dyn_group_id])
383 .store(accum + local_group_sum);
384 atomic_ref_state(acc_tile_state[group_id.dyn_group_id]).store(STATE_P);
386 local_sum[0] = accum;
390 id.barrier(sycl::access::fence_space::local_space);
393 if (group_id.dyn_global_id < len) {
394 acc_value[group_id.dyn_global_id] = local_scan + local_sum[0];
// NOTE(review): garbled extraction — fused leading integers are the original
// file's line numbers; gaps show missing lines (declarations of `group_cnt`,
// `local_val`, `accum`, spin/walk logic, closing braces, the return). Code
// tokens are byte-identical; only comments were added.
//
// Decoupled look-back variant where each tile's (state, value) pair is packed
// into a single u64 via sham::pack32 and published with ONE atomic store —
// state and value can never be observed inconsistently, unlike the v1
// three-buffer scheme. The packing implies T is a 32-bit type here.
403 template<
class T, u32 group_size>
406 template<
class T, u32 group_size>
407 sycl::buffer<T> exclusive_sum_atomic_decoupled_v2(
408 sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
// NOTE(review): `group_cnt` declaration missing from this chunk.
411 u32 corrected_len = group_cnt * group_size;
414 sycl::buffer<T> ret_buf(len);
// Pass 1: exclusive shift of the input.
416 q.submit([&, len](sycl::handler &cgh) {
417 sycl::accessor acc_in{buf1, cgh, sycl::read_only};
418 sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};
420 cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
421 u32 thid =
id.get_linear_id();
422 acc_out[id] = (thid > 0) ? acc_in[thid - 1] : 0;
// One packed u64 descriptor per tile: (state, aggregate-or-prefix).
// NOTE(review): initialization of tile_state to STATE_X is not visible in
// this chunk; confirm in the original.
430 sycl::buffer<u64> tile_state(group_cnt);
432 constexpr T STATE_X = 0;
433 constexpr T STATE_A = 1;
434 constexpr T STATE_P = 2;
438 atomic::DynamicIdGenerator<i32, group_size> id_gen(q);
440 q.submit([&, group_cnt, len](sycl::handler &cgh) {
441 sycl::accessor acc_value{ret_buf, cgh, sycl::read_write};
442 sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};
444 auto dyn_id = id_gen.get_access(cgh);
446 sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
447 sycl::local_accessor<T, 1> local_sum{1, cgh};
// NOTE(review): the value-type template argument line (presumably u64) is
// missing from the chunk (numbering skips 450).
449 using atomic_ref_T = sycl::atomic_ref<
451 sycl::memory_order_relaxed,
452 sycl::memory_scope_device,
453 sycl::access::address_space::global_space>;
455 cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled_v2<T, group_size>>(
456 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
457 atomic::DynamicId<i32> group_id = dyn_id.compute_id(
id);
// NOTE(review): `local_val` declaration missing — presumably `T local_val = 0;`.
461 if (group_id.dyn_global_id < len) {
462 local_val = acc_value[group_id.dyn_global_id];
469 T local_scan = sycl::inclusive_scan_over_group(
470 id.get_group(), local_val, sycl::plus<T>{});
472 if (
id.get_local_id(0) == group_size - 1) {
473 local_scan_buf[0] = local_scan;
477 id.barrier(sycl::access::fence_space::local_space);
// Helpers: single-atomic publish/read of a packed (state, value) pair.
479 auto store = [=](
u32 id, T state, T val) {
480 atomic_ref_T(acc_tile_state[
id]).store(sham::pack32(state, val));
483 auto load = [=](
u32 id) -> sycl::vec<T, 2> {
484 return sham::unpack32(atomic_ref_T(acc_tile_state[
id]).load());
488 if (group_id.is_main_thread) {
491 T local_group_sum = local_scan_buf[0];
493 u32 tile_ptr = group_id.dyn_group_id - 1;
497 if (group_id.dyn_group_id != 0) {
// Publish our aggregate so later tiles can make progress.
499 store(group_id.dyn_group_id, STATE_A, local_group_sum);
// Look-back walk. NOTE(review): `accum` declaration and the branch bodies
// (spin on X, keep walking on A, terminate on P) are missing from this
// chunk — only the state dispatch is visible.
503 sycl::vec<T, 2> state = load(tile_ptr);
505 if (state.x() == STATE_X) {
509 if (state.x() == STATE_A) {
513 if (state.x() == STATE_P) {
// Publish our inclusive prefix with state P in one atomic store.
522 store(group_id.dyn_group_id, STATE_P, accum + local_group_sum);
524 local_sum[0] = accum;
528 id.barrier(sycl::access::fence_space::local_space);
531 if (group_id.dyn_global_id < len) {
532 acc_value[group_id.dyn_global_id] = local_scan + local_sum[0];
// NOTE(review): garbled extraction — fused leading integers are the original
// file's line numbers; gaps show missing lines (declarations of `group_cnt`,
// `local_val`, `accum`, branch bodies, closing braces, the return). Code
// tokens are byte-identical; only comments were added.
//
// Same packed-u64 decoupled look-back scheme as v2; the only visible
// difference is the atomic_ref memory scope (work_group instead of device).
// NOTE(review): memory_scope_work_group on an atomic used for CROSS-group
// communication looks too weak for portable correctness — it may be a
// deliberate relaxation for specific hardware; confirm intent.
541 template<
class T, u32 group_size>
544 template<
class T, u32 group_size>
545 sycl::buffer<T> exclusive_sum_atomic_decoupled_v3(
546 sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
// NOTE(review): `group_cnt` declaration missing from this chunk.
549 u32 corrected_len = group_cnt * group_size;
552 sycl::buffer<T> ret_buf(len);
// Pass 1: exclusive shift of the input.
554 q.submit([&, len](sycl::handler &cgh) {
555 sycl::accessor acc_in{buf1, cgh, sycl::read_only};
556 sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};
558 cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
559 u32 thid =
id.get_linear_id();
560 acc_out[id] = (thid > 0) ? acc_in[thid - 1] : 0;
// One packed u64 (state, value) descriptor per tile.
568 sycl::buffer<u64> tile_state(group_cnt);
570 constexpr T STATE_X = 0;
571 constexpr T STATE_A = 1;
572 constexpr T STATE_P = 2;
576 atomic::DynamicIdGenerator<i32, group_size> id_gen(q);
578 q.submit([&, group_cnt, len](sycl::handler &cgh) {
579 sycl::accessor acc_value{ret_buf, cgh, sycl::read_write};
580 sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};
582 auto dyn_id = id_gen.get_access(cgh);
584 sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
585 sycl::local_accessor<T, 1> local_sum{1, cgh};
// NOTE(review): value-type template argument line missing (numbering skips
// 588); scope is work_group here — see header note above.
587 using atomic_ref_T = sycl::atomic_ref<
589 sycl::memory_order_relaxed,
590 sycl::memory_scope_work_group,
591 sycl::access::address_space::global_space>;
593 cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled_v3<T, group_size>>(
594 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
595 atomic::DynamicId<i32> group_id = dyn_id.compute_id(
id);
// NOTE(review): `local_val` declaration missing — presumably `T local_val = 0;`.
599 if (group_id.dyn_global_id < len) {
600 local_val = acc_value[group_id.dyn_global_id];
607 T local_scan = sycl::inclusive_scan_over_group(
608 id.get_group(), local_val, sycl::plus<T>{});
610 if (
id.get_local_id(0) == group_size - 1) {
611 local_scan_buf[0] = local_scan;
615 id.barrier(sycl::access::fence_space::local_space);
// Helpers: single-atomic publish/read of the packed (state, value) pair.
617 auto store = [=](
u32 id, T state, T val) {
618 atomic_ref_T(acc_tile_state[
id]).store(sham::pack32(state, val));
621 auto load = [=](
u32 id) -> sycl::vec<T, 2> {
622 return sham::unpack32(atomic_ref_T(acc_tile_state[
id]).load());
626 if (group_id.is_main_thread) {
629 T local_group_sum = local_scan_buf[0];
631 u32 tile_ptr = group_id.dyn_group_id - 1;
635 if (group_id.dyn_group_id != 0) {
// Publish our aggregate so later tiles can make progress.
637 store(group_id.dyn_group_id, STATE_A, local_group_sum);
// Look-back walk. NOTE(review): `accum` declaration and the branch bodies
// are missing from this chunk — only the state dispatch is visible.
641 sycl::vec<T, 2> state = load(tile_ptr);
643 if (state.x() == STATE_X) {
647 if (state.x() == STATE_A) {
651 if (state.x() == STATE_P) {
// Publish our inclusive prefix with state P in one atomic store.
660 store(group_id.dyn_group_id, STATE_P, accum + local_group_sum);
662 local_sum[0] = accum;
666 id.barrier(sycl::access::fence_space::local_space);
669 if (group_id.dyn_global_id < len) {
670 acc_value[group_id.dyn_global_id] = local_scan + local_sum[0];
// NOTE(review): garbled extraction — fused leading integers are the original
// file's line numbers; gaps show missing lines (declarations of `group_cnt`
// and `accum`, loop bodies, closing braces, the return). Code tokens are
// byte-identical; only comments were added.
//
// Decoupled look-back variant that pads the working buffer up to
// corrected_len so the kernel needs NO per-element bounds checks: pass 1
// zero-fills the padding, so padded lanes contribute the identity. The
// look-back here spins (do/while) until the predecessor tile reaches a
// non-X state, then follows the A-chain until it finds a P prefix.
679 template<
class T, u32 group_size>
682 template<
class T, u32 group_size>
683 sycl::buffer<T> exclusive_sum_atomic_decoupled_v4(
684 sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
// NOTE(review): `group_cnt` is assigned below but its initial computation is
// missing from this chunk. The `+ (group_cnt % 4)` adjustment presumably
// aligns the group count — confirm the rationale in the original.
688 group_cnt = group_cnt + (group_cnt % 4);
689 u32 corrected_len = group_cnt * group_size;
// Buffer is allocated at the PADDED length (unlike v1-v3).
692 sycl::buffer<T> ret_buf(corrected_len);
// Pass 1: exclusive shift; padding lanes (thid >= len) are written 0.
694 q.submit([&, len](sycl::handler &cgh) {
695 sycl::accessor acc_in{buf1, cgh, sycl::read_only};
696 sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};
698 cgh.parallel_for(sycl::range<1>{corrected_len}, [=](sycl::item<1> id) {
699 u32 thid =
id.get_linear_id();
700 acc_out[id] = (thid > 0 && thid < len) ? acc_in[thid - 1] : 0;
// One packed u64 (state, value) descriptor per tile.
708 sycl::buffer<u64> tile_state(group_cnt);
710 constexpr T STATE_X = 0;
711 constexpr T STATE_A = 1;
712 constexpr T STATE_P = 2;
716 atomic::DynamicIdGenerator<i32, group_size> id_gen(q);
718 q.submit([&, group_cnt, len](sycl::handler &cgh) {
719 sycl::accessor acc_value{ret_buf, cgh, sycl::read_write};
720 sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};
722 auto dyn_id = id_gen.get_access(cgh);
724 sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
725 sycl::local_accessor<T, 1> local_sum{1, cgh};
// NOTE(review): value-type template argument line missing (numbering skips
// 728); work_group scope on a cross-group atomic — same caveat as v3.
727 using atomic_ref_T = sycl::atomic_ref<
729 sycl::memory_order_relaxed,
730 sycl::memory_scope_work_group,
731 sycl::access::address_space::global_space>;
733 cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled_v4<T, group_size>>(
734 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
735 atomic::DynamicId<i32> group_id = dyn_id.compute_id(
id);
// No bounds check needed: the buffer is padded to corrected_len.
738 T local_val = acc_value[group_id.dyn_global_id];
742 T local_scan = sycl::inclusive_scan_over_group(
743 id.get_group(), local_val, sycl::plus<T>{});
745 if (
id.get_local_id(0) == group_size - 1) {
746 local_scan_buf[0] = local_scan;
750 id.barrier(sycl::access::fence_space::local_space);
// Helpers: single-atomic publish/read of the packed (state, value) pair.
752 auto store = [=](
u32 id, T state, T val) {
753 atomic_ref_T(acc_tile_state[
id]).store(sham::pack32(state, val));
756 auto load = [=](
u32 id) -> sycl::vec<T, 2> {
757 return sham::unpack32(atomic_ref_T(acc_tile_state[
id]).load());
761 if (group_id.is_main_thread) {
764 T local_group_sum = local_scan_buf[0];
766 u32 tile_ptr = group_id.dyn_group_id - 1;
767 sycl::vec<T, 2> tile_state = {STATE_X, 0};
771 if (group_id.dyn_group_id != 0) {
// Publish our aggregate so later tiles can make progress.
773 store(group_id.dyn_group_id, STATE_A, local_group_sum);
// Outer loop: keep walking predecessors until a P (prefix) is found;
// inner do/while spins while the predecessor is still in state X.
// NOTE(review): `accum` declaration and the tile_ptr decrement are among
// the lines missing from this chunk.
775 while (tile_state.x() != STATE_P) {
777 atomic_ref_T atomic_state(acc_tile_state[tile_ptr]);
780 tile_state = sham::unpack32(atomic_state.load());
781 }
while (tile_state.x() == STATE_X);
// y() carries the aggregate (state A) or inclusive prefix (state P).
783 accum += tile_state.y();
// Publish our inclusive prefix with state P in one atomic store.
789 store(group_id.dyn_group_id, STATE_P, accum + local_group_sum);
791 local_sum[0] = accum;
795 id.barrier(sycl::access::fence_space::local_space);
// Unconditional write is safe: the buffer is padded to corrected_len.
798 acc_value[group_id.dyn_global_id] = local_scan + local_sum[0];
// NOTE(review): garbled extraction — fused leading integers are the original
// file's line numbers; gaps show missing lines, and this function CONTINUES
// PAST the end of the visible chunk (its tail, including the return, is not
// shown). Code tokens are byte-identical; only comments were added.
//
// Exclusive prefix sum built on the SYCL group algorithm
// sycl::joint_inclusive_scan: pass 1 shifts the input right by one into a
// padded buffer, then every work-group jointly scans the WHOLE buffer into
// ret_buf2. NOTE(review): every group appears to scan the entire range —
// presumably a correctness/benchmark baseline rather than an efficient
// implementation; confirm against the original file.
805 template<
class T, u32 group_size>
808 template<
class T, u32 group_size>
809 sycl::buffer<T> exclusive_sum_sycl_jointalg(sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
// NOTE(review): initial computation of `group_cnt` is missing from this chunk.
813 group_cnt = group_cnt + (group_cnt % 4);
814 u32 corrected_len = group_cnt * group_size;
// Padded input and output buffers.
817 sycl::buffer<T> ret_buf(corrected_len);
818 sycl::buffer<T> ret_buf2(corrected_len);
// Pass 1: exclusive shift; padding lanes are written 0.
820 q.submit([&, len](sycl::handler &cgh) {
821 sycl::accessor acc_in{buf1, cgh, sycl::read_only};
822 sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};
824 cgh.parallel_for(sycl::range<1>{corrected_len}, [=](sycl::item<1> id) {
825 u32 thid =
id.get_linear_id();
826 acc_out[id] = (thid > 0 && thid < len) ? acc_in[thid - 1] : 0;
// NOTE(review): tile_state and the STATE_* constants appear unused in the
// visible part of this function — possibly leftovers from the decoupled
// variants; confirm whether the missing tail uses them.
834 sycl::buffer<u64> tile_state(group_cnt);
836 constexpr T STATE_X = 0;
837 constexpr T STATE_A = 1;
838 constexpr T STATE_P = 2;
842 q.submit([&, group_cnt, len](sycl::handler &cgh) {
843 sycl::accessor acc_in{ret_buf, cgh, sycl::read_write};
844 sycl::accessor acc_out{ret_buf2, cgh, sycl::read_write};
846 cgh.parallel_for<KernelExclusivesum_sycl_jointalg<T, group_size>>(
847 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
// Raw pointers over the accessors' full range for the joint algorithm.
848 T *first = &(acc_in[0]);
849 T *last = first + acc_in.size();
851 T *first_out = &(acc_out[0]);
// Group-collective scan over [first, last) — all work-items of the group
// must reach this call with the same arguments.
854 sycl::joint_inclusive_scan(
855 id.get_group(), first, last, first_out, sycl::plus<T>{});