28namespace shamalgs::numeric::details {
// ---------------------------------------------------------------------------
// NOTE(review): fragmented extract — the enclosing struct header (presumably
// `template<class T> struct ScanTile { ... }`) and several interior lines are
// not visible here. Comments describe only what the visible code shows.
// ---------------------------------------------------------------------------
// Tile status flags for the decoupled look-back scan:
//   STATE_X : tile has published nothing yet (invalid),
//   STATE_A : tile's local aggregate has been published,
//   STATE_P : tile's inclusive prefix has been published.
33 static constexpr T STATE_X = 0;
34 static constexpr T STATE_A = 1;
35 static constexpr T STATE_P = 2;
// Status flag and value are packed into one 64-bit word so both can be
// read/written with a single atomic operation.
37 using PackStorage =
u64;
// state.x() = status flag, state.y() = published aggregate/prefix value.
39 sycl::vec<T, 2> state;
// True once this tile's inclusive prefix is available (STATE_P).
43 inline bool has_prefix_available() {
return state.x() == STATE_P; }
// Value published by the tile (the prefix once STATE_P is reached).
45 inline T get_prefix() {
return state.y(); }
// Rebuild a tile from its packed atomic representation.
47 inline static ScanTile unpack(PackStorage s) {
return ScanTile{sham::unpack32(s)}; }
// Pack (flag, value) into one word for a single atomic store.
49 inline static PackStorage pack(T a, T b) {
return sham::pack32(a, b); }
// Negation of has_prefix_available(); used as the look-back loop condition.
51 inline bool has_no_prefix() {
return state.x() != STATE_P; }
// True while the tile has published nothing yet (STATE_X).
53 inline bool is_invalid() {
return state.x() == STATE_X; }
// NOTE(review): fragment of a second ScanTile specialization that packs the
// status flag and a 30-bit value into a single u32 (`(a << 30U) + b`). The
// struct header and the matching unpack() are not visible here; `mask` below
// is presumably applied there — TODO confirm callers guarantee b < 2^30,
// otherwise pack() overflows into the flag bits.
58 static constexpr u32 STATE_X = 0;
59 static constexpr u32 STATE_A = 1;
60 static constexpr u32 STATE_P = 2;
// Whole tile state fits in one 32-bit word (flag in the top bits).
62 using PackStorage =
u32;
// state.x() = status flag, state.y() = published aggregate/prefix value.
64 sycl::vec<u32, 2> state;
// True once this tile's inclusive prefix is available (STATE_P).
68 inline bool has_prefix_available() {
return state.x() == STATE_P; }
70 inline u32 get_prefix() {
return state.y(); }
// Low 30 bits of the packed word hold the value; upper bits hold the flag.
74 constexpr u32 mask = (1U << 30U) - 1U;
79 inline static PackStorage pack(
u32 a,
u32 b) {
return (a << 30U) + b; }
// Negation of has_prefix_available(); used as the look-back loop condition.
81 inline bool has_no_prefix() {
return state.x() != STATE_P; }
// True while the tile has published nothing yet (STATE_X).
83 inline bool is_invalid() {
return state.x() == STATE_X; }
// Look-back strategy selector. Standard = one thread walks predecessor tiles;
// Parallelized = presumably the multi-thread window probing used by the
// thread_counts-based v6 kernel below — confirm where this enum is dispatched.
86 enum DecoupledLookBackPolicy { Standard, Parallelized };
// Forward declaration. NOTE(review): "LoockBack" is a typo, but it is the
// name used consistently in this file — renaming would break callers.
88 template<
class T, u32 group_size, DecoupledLookBackPolicy policy,
class Tile>
89 class ScanDecoupledLoockBack;
// Device-side accessor companion of ScanDecoupledLoockBack (the class header
// and several member declarations are missing from this extract).
91 template<
class T, u32 group_size, DecoupledLookBackPolicy policy,
class Tile>
// Global tile-state array: one packed entry per (tile, slice).
94 sycl::accessor<typename Tile::PackStorage, 1, sycl::access::mode::read_write>
// Work-group local scratch: the group's scan total, and the look-back prefix.
97 sycl::local_accessor<T, 1> local_scan_buf;
98 sycl::local_accessor<T, 1> local_sum;
102 using atomic_ref_T = sycl::atomic_ref<
103 typename Tile::PackStorage,
104 sycl::memory_order_relaxed,
// NOTE(review): the tile state is read by OTHER work-groups during look-back,
// yet the scope here is work_group while the host-side class below uses
// memory_scope_device — confirm work_group scope is sufficient on all
// targeted backends.
105 sycl::memory_scope_work_group,
106 sycl::access::address_space::global_space>;
// Constructor fragment: binds the shared tile-state buffer, allocates the
// two one-element local scratch buffers.
112 : acc_tile_state{scan.tile_state, cgh, sycl::read_write}, local_scan_buf{1, cgh},
113 local_sum{1, cgh}, group_count(group_count) {}
115 template<
class InputGetter,
class OutputSetter>
// Single-thread decoupled look-back: publish this tile's aggregate, then
// walk predecessor tiles until a published prefix (STATE_P) is found.
116 inline void decoupled_lookback_scan(
119 const u32 group_tile_id,
// slice_id selects an independent slice of tile states (multi-scan support).
122 u32 slice_id = 0)
const {
124 u32 pointer_offset = slice_id * group_count;
128 atomic_ref_T tile_atomic(acc_tile_state[group_tile_id + pointer_offset]);
131 T local_group_sum = input();
// Look-back starts at the immediately preceding tile.
133 u32 tile_ptr = group_tile_id - 1;
134 Tile tile_state = Tile::invalid();
// Tile 0 has no predecessor: it publishes its prefix directly.
138 if (group_tile_id != 0) {
// Publish the local aggregate so successors are not blocked on this tile.
140 tile_atomic.store(Tile::pack(Tile::STATE_A, local_group_sum));
142 while (tile_state.has_no_prefix()) {
144 atomic_ref_T atomic_state(acc_tile_state[tile_ptr + pointer_offset]);
// Spin until the predecessor publishes something (A or P).
147 tile_state = Tile::unpack(atomic_state.load());
148 }
while (tile_state.is_invalid());
// Fold the predecessor's contribution into the running prefix.
150 accum += tile_state.get_prefix();
// Publish this tile's inclusive prefix for its successors.
156 tile_atomic.store(Tile::pack(Tile::STATE_P, accum + local_group_sum));
162 id.barrier(sycl::access::fence_space::local_space);
// scan(): group-level entry point (signature head missing from extract).
168 const u32 group_tile_id,
170 u32 slice_id = 0)
const {
// Intra-group inclusive scan of the per-item values.
174 T local_scan = sycl::inclusive_scan_over_group(
id.get_group(), input, sycl::plus<T>{});
// The last lane of an inclusive scan carries the group total; stash it.
177 if (local_id == group_size - 1) {
178 local_scan_buf[0] = local_scan;
182 id.barrier(sycl::access::fence_space::local_space);
184 decoupled_lookback_scan(
189 return local_scan_buf[0];
// Broadcast the look-back prefix to the whole group via local memory.
192 local_sum[0] = accum;
// Final result = intra-group scan + prefix of all preceding tiles.
196 return local_scan + local_sum[0];
// Host-side ScanDecoupledLoockBack state holder (class header missing from
// this extract).
200 template<
class T, u32 group_size, DecoupledLookBackPolicy policy,
class Tile>
// One packed tile state per (group, slice).
206 sycl::buffer<typename Tile::PackStorage> tile_state;
// Constructor fragment: sizes the tile-state buffer as group_count * slice_count.
209 : slice_count(slice_count), group_count(group_count),
210 tile_state(group_count * slice_count) {
215 using atomic_ref_T = sycl::atomic_ref<
216 typename Tile::PackStorage,
217 sycl::memory_order_relaxed,
// Device scope here: the tile state is shared across work-groups.
218 sycl::memory_scope_device,
219 sycl::access::address_space::global_space>;
// get_access(): builds the device-side accessor companion for a kernel.
222 sycl::handler &cgh) {
224 cgh, *
this, group_count};
// Kernel name tag (declaration fragment).
228 template<
class T, u32 group_size>
231 template<
class T, u32 group_size>
// In-place exclusive sum over buf1[0..len) using the decoupled look-back
// helper above. The function's closing braces are missing from this extract.
232 void exclusive_sum_in_place_atomic_decoupled_v5(
233 sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
// NOTE(review): `group_cnt + (group_cnt % 4)` does NOT round up to a
// multiple of 4 (e.g. 5 -> 6, 7 -> 10). If alignment to 4 groups is the
// intent, this should be (group_cnt + 3) & ~3u — confirm intent; the same
// expression is repeated in every variant below.
236 group_cnt = group_cnt + (group_cnt % 4);
237 u32 corrected_len = group_cnt * group_size;
242 q.submit([&, group_cnt, len](sycl::handler &cgh) {
243 sycl::accessor acc_value{buf1, cgh, sycl::read_write};
245 auto scanop = dlookbackscan.get_access(cgh);
248 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
249 u32 local_id =
id.get_local_id(0);
250 u32 group_tile_id =
id.get_group_linear_id();
251 u32 global_id = group_tile_id * group_size + local_id;
// Exclusive shift: element i reads value i-1; element 0 reads the identity.
254 T local_val = (global_id > 0 && global_id < len) ? acc_value[global_id - 1] : 0;
256 T scanned_value = scanop.scan(
id, local_id, group_tile_id, local_val);
// Tail guard: corrected_len may exceed len.
259 if (global_id < len) {
260 acc_value[global_id] = scanned_value;
// Kernel name tag (declaration fragment).
266 template<
class T, u32 group_size>
269 template<
class T, u32 group_size>
// Out-of-place exclusive sum: reads buf1, returns a buffer of size
// corrected_len (only the first `len` entries are meaningful).
270 sycl::buffer<T> exclusive_sum_atomic_decoupled_v5(
271 sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
// NOTE(review): not a round-up to a multiple of 4 (see in-place v5 variant).
275 group_cnt = group_cnt + (group_cnt % 4);
276 u32 corrected_len = group_cnt * group_size;
279 sycl::buffer<T> ret_buf(corrected_len);
285 sycl::buffer<typename ScanTile<T>::PackStorage> tile_state(group_cnt);
// Tile status flags (same encoding as ScanTile): X invalid, A aggregate,
// P prefix.
287 constexpr T STATE_X = 0;
288 constexpr T STATE_A = 1;
289 constexpr T STATE_P = 2;
295 q.submit([&, group_cnt, len](sycl::handler &cgh) {
// Dynamic group ids — presumably so tiles start in launch order and the
// look-back spin loop cannot deadlock; confirm against DynamicId docs.
296 auto dyn_id = id_gen.get_access(cgh);
298 sycl::accessor acc_in{buf1, cgh, sycl::read_only};
299 sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};
300 sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};
302 sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
303 sycl::local_accessor<T, 1> local_sum{1, cgh};
305 using atomic_ref_T = sycl::atomic_ref<
307 sycl::memory_order_relaxed,
308 sycl::memory_scope_device,
309 sycl::access::address_space::global_space>;
311 cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled_v5<T, group_size>>(
312 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
313 u32 local_id =
id.get_local_id(0);
315 atomic::DynamicId<i32> group_id = dyn_id.compute_id(
id);
317 u32 group_tile_id = group_id.dyn_group_id;
318 u32 global_id = group_id.dyn_global_id;
// Exclusive shift: element i reads input i-1; element 0 reads the identity.
323 T local_val = (global_id > 0 && global_id < len) ? acc_in[global_id - 1] : 0;
327 T local_scan = sycl::inclusive_scan_over_group(
328 id.get_group(), local_val, sycl::plus<T>{});
// The last lane of the inclusive scan carries the group total.
331 if (local_id == group_size - 1) {
332 local_scan_buf[0] = local_scan;
336 id.barrier(sycl::access::fence_space::local_space);
341 atomic_ref_T tile_atomic(acc_tile_state[group_tile_id]);
344 T local_group_sum = local_scan_buf[0];
346 u32 tile_ptr = group_tile_id - 1;
347 sycl::vec<T, 2> tile_state = {STATE_X, 0};
// Tile 0 publishes its prefix directly; all others look back first.
351 if (group_tile_id != 0) {
// Publish the aggregate so successors are not blocked on this tile.
353 tile_atomic.store(sham::pack32(STATE_A, local_group_sum));
355 while (tile_state.x() != STATE_P) {
357 atomic_ref_T atomic_state(acc_tile_state[tile_ptr]);
// Spin until the predecessor publishes A or P.
360 tile_state = sham::unpack32(atomic_state.load());
361 }
while (tile_state.x() == STATE_X);
363 accum += tile_state.y();
// Publish the inclusive prefix of all tiles up to and including this one.
369 tile_atomic.store(sham::pack32(STATE_P, accum + local_group_sum));
// Broadcast the look-back prefix to the whole group via local memory.
371 local_sum[0] = accum;
375 id.barrier(sycl::access::fence_space::local_space);
// Tail guard: corrected_len may exceed len.
378 if (global_id < len) {
379 acc_out[global_id] = local_scan + local_sum[0];
// Kernel name tag (declaration fragment).
387 template<
class T, u32 group_size>
390 template<
class T, u32 group_size>
// USM variant of the v5 exclusive sum. The function signature line is
// missing from this extract; the kernel tag names it ..._v5_USM.
// NOTE(review): same non-round-up `group_cnt % 4` expression as above.
396 group_cnt = group_cnt + (group_cnt % 4);
397 u32 corrected_len = group_cnt * group_size;
403 sycl::buffer<typename ScanTile<T>::PackStorage> tile_state(group_cnt);
// Tile status flags: X invalid, A aggregate published, P prefix published.
405 constexpr T STATE_X = 0;
406 constexpr T STATE_A = 1;
407 constexpr T STATE_P = 2;
// Reset every tile state to invalid before the kernel runs.
410 dev_sched->get_queue().q, tile_state, sham::pack32(STATE_X, T(0)));
416 T *out_ptr = ret_buf.get_write_access(depends_list);
418 sycl::event e = dev_sched->get_queue().submit(
419 depends_list, [&, group_cnt, len, in_ptr, out_ptr](sycl::handler &cgh) {
// Dynamic group ids — presumably to guarantee in-order tile start so the
// look-back spin loop cannot deadlock; confirm against DynamicId docs.
420 auto dyn_id = id_gen.get_access(cgh);
422 sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};
424 sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
425 sycl::local_accessor<T, 1> local_sum{1, cgh};
427 using atomic_ref_T = sycl::atomic_ref<
429 sycl::memory_order_relaxed,
430 sycl::memory_scope_device,
431 sycl::access::address_space::global_space>;
433 cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled_v5_USM<T, group_size>>(
434 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
435 u32 local_id =
id.get_local_id(0);
437 atomic::DynamicId<i32> group_id = dyn_id.compute_id(
id);
439 u32 group_tile_id = group_id.dyn_group_id;
440 u32 global_id = group_id.dyn_global_id;
// Exclusive shift: element i reads input i-1; element 0 reads the identity.
446 = (global_id > 0 && global_id < len) ? in_ptr[global_id - 1] : 0;
450 T local_scan = sycl::inclusive_scan_over_group(
451 id.get_group(), local_val, sycl::plus<T>{});
// The last lane of the inclusive scan carries the group total.
454 if (local_id == group_size - 1) {
455 local_scan_buf[0] = local_scan;
459 id.barrier(sycl::access::fence_space::local_space);
464 atomic_ref_T tile_atomic(acc_tile_state[group_tile_id]);
467 T local_group_sum = local_scan_buf[0];
469 u32 tile_ptr = group_tile_id - 1;
470 sycl::vec<T, 2> tile_state = {STATE_X, 0};
// Tile 0 publishes directly; others publish A then look back.
474 if (group_tile_id != 0) {
476 tile_atomic.store(sham::pack32(STATE_A, local_group_sum));
478 while (tile_state.x() != STATE_P) {
480 atomic_ref_T atomic_state(acc_tile_state[tile_ptr]);
// Spin until the predecessor publishes A or P.
483 tile_state = sham::unpack32(atomic_state.load());
484 }
while (tile_state.x() == STATE_X);
486 accum += tile_state.y();
// Publish this tile's inclusive prefix.
496 tile_atomic.store(sham::pack32(STATE_P, accum + local_group_sum));
// Broadcast the look-back prefix to the whole group via local memory.
498 local_sum[0] = accum;
502 id.barrier(sycl::access::fence_space::local_space);
// Tail guard: corrected_len may exceed len.
505 if (global_id < len) {
506 out_ptr[global_id] = local_scan + local_sum[0];
// Attach the kernel event to the USM buffer's dependency tracking.
511 ret_buf.complete_event_state(e);
// Kernel name tag (declaration fragment).
519 template<
class T, u32 group_size>
522 template<
class T, u32 group_size>
// In-place USM exclusive sum over in_out_ptr[0..len).
523 void exclusive_sum_atomic_decoupled_v5_usm_in_place(
// NOTE(review): same non-round-up `group_cnt % 4` expression as above.
529 group_cnt = group_cnt + (group_cnt % 4);
530 u32 corrected_len = group_cnt * group_size;
// Tile status flags: X invalid, A aggregate published, P prefix published.
538 constexpr T STATE_X = 0;
539 constexpr T STATE_A = 1;
540 constexpr T STATE_P = 2;
// Reset every tile state to invalid before launching.
544 tile_state.fill(sham::pack32(STATE_X, T(0)));
550 auto acc_tile_state = tile_state.get_write_access(depends_list);
552 sycl::event e = dev_sched->get_queue().submit(
553 depends_list, [&, group_cnt, len, in_out_ptr](sycl::handler &cgh) {
554 auto dyn_id = id_gen.get_access(cgh);
558 sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
559 sycl::local_accessor<T, 1> local_sum{1, cgh};
561 using atomic_ref_T = sycl::atomic_ref<
563 sycl::memory_order_relaxed,
564 sycl::memory_scope_device,
565 sycl::access::address_space::global_space>;
568 KernelExclusiveSumAtomicSyncDecoupled_v5_USM_IN_PLACE<T, group_size>>(
569 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
570 u32 local_id =
id.get_local_id(0);
572 atomic::DynamicId<i32> group_id = dyn_id.compute_id(
id);
574 u32 group_tile_id = group_id.dyn_group_id;
575 u32 global_id = group_id.dyn_global_id;
// In-place: read element i directly (no i-1 shift here) ...
580 T local_val = (global_id < len) ? in_out_ptr[global_id] : 0;
// ... because an EXCLUSIVE (not inclusive) group scan provides the shift.
584 T local_scan = sycl::exclusive_scan_over_group(
585 id.get_group(), local_val, sycl::plus<T>{});
// Group total = last lane's exclusive scan + that lane's own value.
588 if (local_id == group_size - 1) {
589 local_scan_buf[0] = local_scan + local_val;
593 id.barrier(sycl::access::fence_space::local_space);
598 atomic_ref_T tile_atomic(acc_tile_state[group_tile_id]);
601 T local_group_sum = local_scan_buf[0];
603 u32 tile_ptr = group_tile_id - 1;
604 sycl::vec<T, 2> tile_state = {STATE_X, 0};
// Tile 0 publishes directly; others publish A then look back.
608 if (group_tile_id != 0) {
610 tile_atomic.store(sham::pack32(STATE_A, local_group_sum));
612 while (tile_state.x() != STATE_P) {
614 atomic_ref_T atomic_state(acc_tile_state[tile_ptr]);
// Spin until the predecessor publishes A or P.
617 tile_state = sham::unpack32(atomic_state.load());
618 }
while (tile_state.x() == STATE_X);
620 accum += tile_state.y();
// Publish this tile's inclusive prefix.
626 tile_atomic.store(sham::pack32(STATE_P, accum + local_group_sum));
// Broadcast the look-back prefix to the whole group via local memory.
628 local_sum[0] = accum;
632 id.barrier(sycl::access::fence_space::local_space);
// Tail guard: corrected_len may exceed len.
635 if (global_id < len) {
636 in_out_ptr[global_id] = local_scan + local_sum[0];
// Attach the kernel event to the tile-state buffer's dependency tracking.
641 tile_state.complete_event_state(e);
// Kernel name tag (declaration fragment).
644 template<
class T, u32 group_size, u32 thread_counts>
647 template<
class T, u32 group_size, u32 thread_counts>
// v6: like v5, but the look-back probes `thread_counts` predecessor tiles
// per iteration (one per lane) instead of walking them one at a time.
648 sycl::buffer<T> exclusive_sum_atomic_decoupled_v6(
649 sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
// NOTE(review): same non-round-up `group_cnt % 4` expression as above.
653 group_cnt = group_cnt + (group_cnt % 4);
654 u32 corrected_len = group_cnt * group_size;
657 sycl::buffer<T> ret_buf(corrected_len);
663 sycl::buffer<typename ScanTile<T>::PackStorage> tile_state(group_cnt);
// Tile status flags: X invalid, A aggregate published, P prefix published.
665 constexpr T STATE_X = 0;
666 constexpr T STATE_A = 1;
667 constexpr T STATE_P = 2;
671 q.submit([&, group_cnt, len](sycl::handler &cgh) {
672 sycl::accessor acc_in{buf1, cgh, sycl::read_only};
673 sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};
674 sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};
676 sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
677 sycl::local_accessor<T, 1> local_sum{1, cgh};
681 using atomic_ref_T = sycl::atomic_ref<
683 sycl::memory_order_relaxed,
// NOTE(review): work_group scope on state shared BETWEEN work-groups —
// the v5 kernels use memory_scope_device here. Confirm this is safe on
// all targeted backends.
684 sycl::memory_scope_work_group,
685 sycl::access::address_space::global_space>;
688 KernelExclusiveSumAtomicSyncDecoupled_v6<T, group_size, thread_counts>>(
689 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
690 u32 local_id =
id.get_local_id(0);
691 u32 group_tile_id =
id.get_group_linear_id();
692 u32 global_id = group_tile_id * group_size + local_id;
694 auto local_group =
id.get_group();
// Exclusive shift: element i reads input i-1; element 0 reads the identity.
697 T local_val = (global_id > 0 && global_id < len) ? acc_in[global_id - 1] : 0;
703 = sycl::inclusive_scan_over_group(local_group, local_val, sycl::plus<T>{});
// The last lane of the inclusive scan carries the group total.
705 if (local_id == group_size - 1) {
706 local_scan_buf[0] = local_scan;
710 id.barrier(sycl::access::fence_space::local_space);
// Parallel look-back needs at most one probing lane per group member.
713 static_assert(thread_counts <= group_size,
"impossible");
715 T local_group_sum = local_scan_buf[0];
721 if (group_tile_id != 0) {
// Publish the aggregate so successors are not blocked on this tile.
723 atomic_ref_T(acc_tile_state[group_tile_id])
724 .store(sham::pack32(STATE_A, local_group_sum));
727 sycl::vec<T, 2> tile_state;
728 u32 group_tile_ptr = group_tile_id - 1;
730 bool continue_loop =
true;
// Each of the first thread_counts lanes probes one predecessor tile of
// the current look-back window.
734 if ((local_id < thread_counts) && (group_tile_ptr >= local_id)) {
735 atomic_ref_T atomic_state(
736 acc_tile_state[group_tile_ptr - local_id]);
// Spin until that tile publishes A or P.
739 tile_state = sham::unpack32(atomic_state.load());
740 }
while (tile_state.x() == STATE_X);
// Non-probing lanes contribute a neutral (A, 0) entry.
743 tile_state = {STATE_A, 0};
// All-A lanes (flag value 1) sum to exactly group_size; any P (flag
// value 2) pushes the sum above it, i.e. a full prefix was found.
748 sum_state = sycl::reduce_over_group(
749 local_group, tile_state.x(), sycl::plus<T>{});
753 if (sum_state > group_size) {
756 continue_loop =
false;
// Closest predecessor (smallest lane id) holding a P state; reduction
// op line is missing from this extract — presumably a minimum.
758 last_p_index = sycl::reduce_over_group(
760 (tile_state.x() == STATE_P) ? (local_id) : (group_size),
// Drop contributions of tiles older than the found prefix.
765 tile_state.y() = (local_id <= last_p_index) ? tile_state.y() : 0;
// Otherwise slide the window thread_counts tiles further back.
771 continue_loop = (group_tile_ptr >= thread_counts);
772 group_tile_ptr -= thread_counts;
// Accumulate this window's contributions.
775 accum += sycl::reduce_over_group(
776 local_group, tile_state.y(), sycl::plus<T>{});
780 }
while (continue_loop);
// Publish this tile's inclusive prefix.
784 atomic_ref_T(acc_tile_state[group_tile_id])
785 .store(sham::pack32(STATE_P, accum + local_group_sum));
// Tail guard: corrected_len may exceed len.
789 if (global_id < len) {
790 acc_out[global_id] = accum + local_scan;
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
std::shared_ptr< DeviceScheduler > & get_dev_scheduler_ptr()
Gets the Device scheduler pointer corresponding to the held allocation.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
Class to manage a list of SYCL events.
SYCL utility to dynamically generate group ids.
void buf_fill_discard(sycl::queue &q, sycl::buffer< T > &buf, T value)
Fill a buffer with a given value (sycl::no_init mode)
constexpr u32 group_count(u32 len, u32 group_size)
Calculates the number of groups based on the length and group size.
main include file for memory algorithms
This file contains the definition for the stacktrace related functionality.