scanDecoupledLookback.hpp
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

#pragma once

#include "shambase/integer.hpp"
#include "shamalgs/memory.hpp"
#include "shambackends/math.hpp"
#include "shambackends/sycl.hpp"

namespace shamalgs::numeric::details {

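    // Single-pass "decoupled look-back" prefix sum (in the spirit of Merrill &
    // Garland's technique): each work-group publishes its local aggregate in a
    // global tile-state array, then walks backwards over the tiles of the
    // preceding groups, accumulating their contributions until it meets a tile
    // that already exposes a full prefix. A tile advertises one of three states:
    //   STATE_X : nothing published yet (invalid)
    //   STATE_A : only the aggregate of that tile is available
    //   STATE_P : the inclusive prefix up to and including that tile is available
    // ScanTile packs the state and the value into a single 64-bit word so that
    // both can be read or written with one relaxed atomic operation.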
    template<class T>
    class ScanTile {
        public:
        static constexpr T STATE_X = 0;
        static constexpr T STATE_A = 1;
        static constexpr T STATE_P = 2;

        using PackStorage = u64;

        sycl::vec<T, 2> state;

        inline static ScanTile invalid() { return ScanTile{{STATE_X, 0}}; }

        inline bool has_prefix_available() { return state.x() == STATE_P; }

        inline T get_prefix() { return state.y(); }

        inline static ScanTile unpack(PackStorage s) { return ScanTile{sham::unpack32(s)}; }

        inline static PackStorage pack(T a, T b) { return sham::pack32(a, b); }

        inline bool has_no_prefix() { return state.x() != STATE_P; }

        inline bool is_invalid() { return state.x() == STATE_X; }
    };

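    // Variant of ScanTile that fits in a single 32-bit word: the two high bits hold
    // the state and the low 30 bits hold the value, so the packed tile can be updated
    // with 32-bit atomics at the cost of limiting prefix values to 2^30 - 1.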
    class ScanTile30bitint {
        public:
        static constexpr u32 STATE_X = 0;
        static constexpr u32 STATE_A = 1;
        static constexpr u32 STATE_P = 2;

        using PackStorage = u32;

        sycl::vec<u32, 2> state;

        inline static ScanTile30bitint invalid() { return ScanTile30bitint{{STATE_X, 0}}; }

        inline bool has_prefix_available() { return state.x() == STATE_P; }

        inline u32 get_prefix() { return state.y(); }

        inline static ScanTile30bitint unpack(PackStorage s) {

            constexpr u32 mask = (1U << 30U) - 1U;

            return ScanTile30bitint{sycl::vec<u32, 2>{s >> 30U, s & mask}};
        }

        inline static PackStorage pack(u32 a, u32 b) { return (a << 30U) + b; }

        inline bool has_no_prefix() { return state.x() != STATE_P; }

        inline bool is_invalid() { return state.x() == STATE_X; }
    };

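    // Look-back policy selector: Standard has a single work-item walk the predecessor
    // tiles, Parallelized spreads the look-back over several work-items (as done by
    // hand in exclusive_sum_atomic_decoupled_v6 below).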
    enum DecoupledLookBackPolicy { Standard, Parallelized };

    template<class T, u32 group_size, DecoupledLookBackPolicy policy, class Tile>
    class ScanDecoupledLoockBack;

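    // Device-side view of a ScanDecoupledLoockBack instance: bundles the tile-state
    // accessor and the local-memory scratch a kernel needs to run the look-back for
    // one work-group.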
    template<class T, u32 group_size, DecoupledLookBackPolicy policy, class Tile>
    class ScanDecoupledLoockBackAccessed {
        public:
        sycl::accessor<typename Tile::PackStorage, 1, sycl::access::mode::read_write>
            acc_tile_state;

        sycl::local_accessor<T, 1> local_scan_buf;
        sycl::local_accessor<T, 1> local_sum;

        u32 group_count;

        using atomic_ref_T = sycl::atomic_ref<
            typename Tile::PackStorage,
            sycl::memory_order_relaxed,
            sycl::memory_scope_work_group,
            sycl::access::address_space::global_space>;

        ScanDecoupledLoockBackAccessed(
            sycl::handler &cgh,
            ScanDecoupledLoockBack<T, group_size, policy, Tile> &scan,
            u32 group_count)
            : acc_tile_state{scan.tile_state, cgh, sycl::read_write}, local_scan_buf{1, cgh},
              local_sum{1, cgh}, group_count(group_count) {}

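        // Runs the look-back for the calling work-group. `input` must return the
        // group-wide sum (it is only invoked by work-item 0) and `out` receives the
        // exclusive prefix of all preceding groups; the trailing barrier makes the
        // result visible to the whole work-group.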
        template<class InputGetter, class OutputSetter>
        inline void decoupled_lookback_scan(
            sycl::nd_item<1> id,
            const u32 local_id,
            const u32 group_tile_id,
            InputGetter input,
            OutputSetter out,
            u32 slice_id = 0) const {

            u32 pointer_offset = slice_id * group_count;

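            // Work-item 0 publishes this group's aggregate (STATE_A), then walks the
            // predecessor tiles backwards: it spins until a tile leaves STATE_X, adds
            // the published value, and stops as soon as it reads a STATE_P tile, whose
            // value already includes every earlier group. It finally publishes the
            // full prefix of this tile with STATE_P.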
            if (local_id == 0) {

                atomic_ref_T tile_atomic(acc_tile_state[group_tile_id + pointer_offset]);

                // load group sum
                T local_group_sum = input();
                T accum = 0;
                u32 tile_ptr = group_tile_id - 1;
                Tile tile_state = Tile::invalid();

                // global scan using atomic counter

                if (group_tile_id != 0) {

                    tile_atomic.store(Tile::pack(Tile::STATE_A, local_group_sum));

                    while (tile_state.has_no_prefix()) {

                        atomic_ref_T atomic_state(acc_tile_state[tile_ptr + pointer_offset]);

                        do {
                            tile_state = Tile::unpack(atomic_state.load());
                        } while (tile_state.is_invalid());

                        accum += tile_state.get_prefix();

                        tile_ptr--;
                    }
                }

                tile_atomic.store(Tile::pack(Tile::STATE_P, accum + local_group_sum));

                out(accum);
            }

            // sync
            id.barrier(sycl::access::fence_space::local_space);
        }

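        // Two-level scan: an inclusive scan inside the work-group plus the decoupled
        // look-back to add the offset of all preceding groups. Returns the inclusive
        // prefix sum of `input` over the whole dispatch (for the given slice).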
        inline T scan(
            sycl::nd_item<1> id,
            const u32 local_id,
            const u32 group_tile_id,
            const T input,
            u32 slice_id = 0) const {

            // local scan in the group
            // the local sum will be in local id `group_size - 1`
            T local_scan = sycl::inclusive_scan_over_group(id.get_group(), input, sycl::plus<T>{});

            // can be removed if i change the index in the look back ?
            if (local_id == group_size - 1) {
                local_scan_buf[0] = local_scan;
            }

            // sync group
            id.barrier(sycl::access::fence_space::local_space);

            decoupled_lookback_scan(
                id,
                local_id,
                group_tile_id,
                [=]() {
                    return local_scan_buf[0];
                },
                [=](T accum) {
                    local_sum[0] = accum;
                },
                slice_id);

            return local_scan + local_sum[0];
        }
    };

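    // Host-side handle of the scan: owns one packed tile per work-group and per slice,
    // initialized to STATE_X so that every look-back starts from unpublished tiles.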
    template<class T, u32 group_size, DecoupledLookBackPolicy policy, class Tile>
    class ScanDecoupledLoockBack {
        public:
        u32 slice_count;
        u32 group_count;

        sycl::buffer<typename Tile::PackStorage> tile_state;

        ScanDecoupledLoockBack(sycl::queue &q, u32 group_count, u32 slice_count = 1)
            : slice_count(slice_count), group_count(group_count),
              tile_state(group_count * slice_count) {

            shamalgs::memory::buf_fill_discard(q, tile_state, Tile::pack(Tile::STATE_X, T(0)));
        }

        using atomic_ref_T = sycl::atomic_ref<
            typename Tile::PackStorage,
            sycl::memory_order_relaxed,
            sycl::memory_scope_device,
            sycl::access::address_space::global_space>;

        inline ScanDecoupledLoockBackAccessed<T, group_size, policy, Tile> get_access(
            sycl::handler &cgh) {
            return ScanDecoupledLoockBackAccessed<T, group_size, policy, Tile>{
                cgh, *this, group_count};
        }
    };

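    // In-place exclusive sum of the first `len` elements of `buf1`, built on the
    // ScanDecoupledLoockBack helper: each work-item loads the element just before its
    // global id (0 for id 0), so the inclusive scan of the shifted values is the
    // exclusive scan of the original ones.
    //
    // Minimal usage sketch (the group size of 128 is an arbitrary choice here):
    //
    //     sycl::queue q{};
    //     sycl::buffer<u32> buf(1024);
    //     // ... fill buf ...
    //     shamalgs::numeric::details::exclusive_sum_in_place_atomic_decoupled_v5<u32, 128>(
    //         q, buf, 1024);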
    template<class T, u32 group_size>
    class KernelExclusiveSumAtomicSyncDecoupled_v5_in_place;

    template<class T, u32 group_size>
    void exclusive_sum_in_place_atomic_decoupled_v5(
        sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {
        u32 group_cnt = shambase::group_count(len, group_size);

        group_cnt = group_cnt + (group_cnt % 4);
        u32 corrected_len = group_cnt * group_size;

        // group aggregates
        ScanDecoupledLoockBack<T, group_size, Standard, ScanTile<T>> dlookbackscan(q, group_cnt);

        q.submit([&, group_cnt, len](sycl::handler &cgh) {
            sycl::accessor acc_value{buf1, cgh, sycl::read_write};

            auto scanop = dlookbackscan.get_access(cgh);

            cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled_v5_in_place<T, group_size>>(
                sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                    u32 local_id = id.get_local_id(0);
                    u32 group_tile_id = id.get_group_linear_id();
                    u32 global_id = group_tile_id * group_size + local_id;

                    // load from global buffer
                    T local_val = (global_id > 0 && global_id < len) ? acc_value[global_id - 1] : 0;

                    T scanned_value = scanop.scan(id, local_id, group_tile_id, local_val);

                    // store final result
                    if (global_id < len) {
                        acc_value[global_id] = scanned_value;
                    }
                });
        });
    }

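    // Out-of-place variant with the look-back written inline in the kernel. Work-groups
    // take their tile index from atomic::DynamicIdGenerator so that tile order follows
    // the order in which groups are actually scheduled, which keeps the backward
    // spin-wait from waiting on a group that has not started yet.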
    template<class T, u32 group_size>
    class KernelExclusiveSumAtomicSyncDecoupled_v5;

    template<class T, u32 group_size>
    sycl::buffer<T> exclusive_sum_atomic_decoupled_v5(
        sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {

        u32 group_cnt = shambase::group_count(len, group_size);

        group_cnt = group_cnt + (group_cnt % 4);
        u32 corrected_len = group_cnt * group_size;

        // prepare the return buffer by shifting values for the exclusive sum
        sycl::buffer<T> ret_buf(corrected_len);

        // logger::raw_ln("shifted : ");
        // shamalgs::memory::print_buf(ret_buf, len, 16,"{:4} ");

        // group aggregates
        sycl::buffer<typename ScanTile<T>::PackStorage> tile_state(group_cnt);

        constexpr T STATE_X = 0;
        constexpr T STATE_A = 1;
        constexpr T STATE_P = 2;

        shamalgs::memory::buf_fill_discard(q, tile_state, sham::pack32(STATE_X, T(0)));

        atomic::DynamicIdGenerator<i32, group_size> id_gen(q);

        q.submit([&, group_cnt, len](sycl::handler &cgh) {
            auto dyn_id = id_gen.get_access(cgh);

            sycl::accessor acc_in{buf1, cgh, sycl::read_only};
            sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};
            sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};

            sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
            sycl::local_accessor<T, 1> local_sum{1, cgh};

            using atomic_ref_T = sycl::atomic_ref<
                u64,
                sycl::memory_order_relaxed,
                sycl::memory_scope_device,
                sycl::access::address_space::global_space>;

            cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled_v5<T, group_size>>(
                sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                    u32 local_id = id.get_local_id(0);

                    atomic::DynamicId<i32> group_id = dyn_id.compute_id(id);

                    u32 group_tile_id = group_id.dyn_group_id;
                    u32 global_id = group_id.dyn_global_id;
                    // u32 group_tile_id = id.get_group_linear_id();
                    // u32 global_id = group_tile_id * group_size + local_id;

                    // load from global buffer
                    T local_val = (global_id > 0 && global_id < len) ? acc_in[global_id - 1] : 0;

                    // local scan in the group
                    // the local sum will be in local id `group_size - 1`
                    T local_scan = sycl::inclusive_scan_over_group(
                        id.get_group(), local_val, sycl::plus<T>{});

                    // can be removed if i change the index in the look back ?
                    if (local_id == group_size - 1) {
                        local_scan_buf[0] = local_scan;
                    }

                    // sync group
                    id.barrier(sycl::access::fence_space::local_space);

                    // DATA PARALLEL C++: MASTERING DPC++ ... device wide synchro
                    if (local_id == 0) {

                        atomic_ref_T tile_atomic(acc_tile_state[group_tile_id]);

                        // load group sum
                        T local_group_sum = local_scan_buf[0];
                        T accum = 0;
                        u32 tile_ptr = group_tile_id - 1;
                        sycl::vec<T, 2> tile_state = {STATE_X, 0};

                        // global scan using atomic counter

                        if (group_tile_id != 0) {

                            tile_atomic.store(sham::pack32(STATE_A, local_group_sum));

                            while (tile_state.x() != STATE_P) {

                                atomic_ref_T atomic_state(acc_tile_state[tile_ptr]);

                                do {
                                    tile_state = sham::unpack32(atomic_state.load());
                                } while (tile_state.x() == STATE_X);

                                accum += tile_state.y();

                                tile_ptr--;
                            }
                        }

                        tile_atomic.store(sham::pack32(STATE_P, accum + local_group_sum));

                        local_sum[0] = accum;
                    }

                    // sync
                    id.barrier(sycl::access::fence_space::local_space);

                    // store final result
                    if (global_id < len) {
                        acc_out[global_id] = local_scan + local_sum[0];
                    }
                });
        });

        return ret_buf;
    }

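    // Same algorithm as exclusive_sum_atomic_decoupled_v5, but the input and output
    // live in sham::DeviceBuffer (USM) and the kernel is chained through an EventList;
    // only the tile-state array still goes through a sycl::buffer.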
    template<class T, u32 group_size>
    class KernelExclusiveSumAtomicSyncDecoupled_v5_USM;

    template<class T, u32 group_size>
    sham::DeviceBuffer<T> exclusive_sum_atomic_decoupled_v5_usm(
        sham::DeviceScheduler_ptr dev_sched, sham::DeviceBuffer<T, sham::device> &buf1, u32 len) {

        u32 group_cnt = shambase::group_count(len, group_size);

        group_cnt = group_cnt + (group_cnt % 4);
        u32 corrected_len = group_cnt * group_size;

        // prepare the return buffer by shifting values for the exclusive sum
        sham::DeviceBuffer<T> ret_buf(corrected_len, dev_sched);

        // group aggregates
        sycl::buffer<typename ScanTile<T>::PackStorage> tile_state(group_cnt);

        constexpr T STATE_X = 0;
        constexpr T STATE_A = 1;
        constexpr T STATE_P = 2;

        shamalgs::memory::buf_fill_discard(
            dev_sched->get_queue().q, tile_state, sham::pack32(STATE_X, T(0)));

        atomic::DynamicIdGenerator<i32, group_size> id_gen(dev_sched->get_queue().q);

        sham::EventList depends_list;
        const T *in_ptr = buf1.get_read_access(depends_list);
        T *out_ptr = ret_buf.get_write_access(depends_list);

        sycl::event e = dev_sched->get_queue().submit(
            depends_list, [&, group_cnt, len, in_ptr, out_ptr](sycl::handler &cgh) {
                auto dyn_id = id_gen.get_access(cgh);

                sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};

                sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
                sycl::local_accessor<T, 1> local_sum{1, cgh};

                using atomic_ref_T = sycl::atomic_ref<
                    u64,
                    sycl::memory_order_relaxed,
                    sycl::memory_scope_device,
                    sycl::access::address_space::global_space>;

                cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled_v5_USM<T, group_size>>(
                    sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                        u32 local_id = id.get_local_id(0);

                        atomic::DynamicId<i32> group_id = dyn_id.compute_id(id);

                        u32 group_tile_id = group_id.dyn_group_id;
                        u32 global_id = group_id.dyn_global_id;
                        // u32 group_tile_id = id.get_group_linear_id();
                        // u32 global_id = group_tile_id * group_size + local_id;

                        // load from global buffer
                        T local_val
                            = (global_id > 0 && global_id < len) ? in_ptr[global_id - 1] : 0;

                        // local scan in the group
                        // the local sum will be in local id `group_size - 1`
                        T local_scan = sycl::inclusive_scan_over_group(
                            id.get_group(), local_val, sycl::plus<T>{});

                        // can be removed if i change the index in the look back ?
                        if (local_id == group_size - 1) {
                            local_scan_buf[0] = local_scan;
                        }

                        // sync group
                        id.barrier(sycl::access::fence_space::local_space);

                        // DATA PARALLEL C++: MASTERING DPC++ ... device wide synchro
                        if (local_id == 0) {

                            atomic_ref_T tile_atomic(acc_tile_state[group_tile_id]);

                            // load group sum
                            T local_group_sum = local_scan_buf[0];
                            T accum = 0;
                            u32 tile_ptr = group_tile_id - 1;
                            sycl::vec<T, 2> tile_state = {STATE_X, 0};

                            // global scan using atomic counter

                            if (group_tile_id != 0) {

                                tile_atomic.store(sham::pack32(STATE_A, local_group_sum));

                                while (tile_state.x() != STATE_P) {

                                    atomic_ref_T atomic_state(acc_tile_state[tile_ptr]);

                                    do {
                                        tile_state = sham::unpack32(atomic_state.load());
                                    } while (tile_state.x() == STATE_X);

                                    accum += tile_state.y();

                                    // if it overflows, tile_ptr == 0, but in that case we can only
                                    // reach this line if the acc_tile_state[0] is in P state.
                                    // Therefore we will never perform an access again with tile_ptr
                                    // so it is safe
                                    tile_ptr--;
                                }
                            }

                            tile_atomic.store(sham::pack32(STATE_P, accum + local_group_sum));

                            local_sum[0] = accum;
                        }

                        // sync
                        id.barrier(sycl::access::fence_space::local_space);

                        // store final result
                        if (global_id < len) {
                            out_ptr[global_id] = local_scan + local_sum[0];
                        }
                    });
            });
        buf1.complete_event_state(e);
        ret_buf.complete_event_state(e);

        // resize back to the requested length, otherwise the returned buffer keeps the
        // padded corrected_len size
        ret_buf.resize(len);

        return ret_buf;
    }

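    // In-place USM variant: `buf1` is overwritten with its exclusive sum. It loads
    // element `global_id` directly (no shift) and therefore uses an exclusive scan
    // over the group instead of an inclusive one.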
    template<class T, u32 group_size>
    class KernelExclusiveSumAtomicSyncDecoupled_v5_USM_IN_PLACE;

    template<class T, u32 group_size>
    void exclusive_sum_atomic_decoupled_v5_usm_in_place(
        sham::DeviceBuffer<T, sham::device> &buf1, u32 len) {
        StackEntry stack_loc{};

        u32 group_cnt = shambase::group_count(len, group_size);

        group_cnt = group_cnt + (group_cnt % 4);
        u32 corrected_len = group_cnt * group_size;

        auto dev_sched = buf1.get_dev_scheduler_ptr();

        // group aggregates
        sham::DeviceBuffer<typename ScanTile<T>::PackStorage> tile_state(group_cnt, dev_sched);
        // sycl::buffer<typename ScanTile<T>::PackStorage> tile_state(group_cnt);

        constexpr T STATE_X = 0;
        constexpr T STATE_A = 1;
        constexpr T STATE_P = 2;

        // shamalgs::memory::buf_fill_discard(
        //     dev_sched->get_queue().q, tile_state, sham::pack32(STATE_X, T(0)));
        tile_state.fill(sham::pack32(STATE_X, T(0)));

        atomic::DynamicIdGenerator<i32, group_size> id_gen(dev_sched->get_queue().q);

        sham::EventList depends_list;
        T *in_out_ptr = buf1.get_write_access(depends_list);
        auto acc_tile_state = tile_state.get_write_access(depends_list);

        sycl::event e = dev_sched->get_queue().submit(
            depends_list, [&, group_cnt, len, in_out_ptr](sycl::handler &cgh) {
                auto dyn_id = id_gen.get_access(cgh);

                // sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};

                sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
                sycl::local_accessor<T, 1> local_sum{1, cgh};

                using atomic_ref_T = sycl::atomic_ref<
                    u64,
                    sycl::memory_order_relaxed,
                    sycl::memory_scope_device,
                    sycl::access::address_space::global_space>;

                cgh.parallel_for<
                    KernelExclusiveSumAtomicSyncDecoupled_v5_USM_IN_PLACE<T, group_size>>(
                    sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                        u32 local_id = id.get_local_id(0);

                        atomic::DynamicId<i32> group_id = dyn_id.compute_id(id);

                        u32 group_tile_id = group_id.dyn_group_id;
                        u32 global_id = group_id.dyn_global_id;
                        // u32 group_tile_id = id.get_group_linear_id();
                        // u32 global_id = group_tile_id * group_size + local_id;

                        // load from global buffer
                        T local_val = (global_id < len) ? in_out_ptr[global_id] : 0;

                        // local scan in the group
                        // the local sum will be in local id `group_size - 1`
                        T local_scan = sycl::exclusive_scan_over_group(
                            id.get_group(), local_val, sycl::plus<T>{});

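                        // the scan above is exclusive, so the last work-item adds its
                        // own value back to publish the total sum of the group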
                        // can be removed if i change the index in the look back ?
                        if (local_id == group_size - 1) {
                            local_scan_buf[0] = local_scan + local_val;
                        }

                        // sync group
                        id.barrier(sycl::access::fence_space::local_space);

                        // DATA PARALLEL C++: MASTERING DPC++ ... device wide synchro
                        if (local_id == 0) {

                            atomic_ref_T tile_atomic(acc_tile_state[group_tile_id]);

                            // load group sum
                            T local_group_sum = local_scan_buf[0];
                            T accum = 0;
                            u32 tile_ptr = group_tile_id - 1;
                            sycl::vec<T, 2> tile_state = {STATE_X, 0};

                            // global scan using atomic counter

                            if (group_tile_id != 0) {

                                tile_atomic.store(sham::pack32(STATE_A, local_group_sum));

                                while (tile_state.x() != STATE_P) {

                                    atomic_ref_T atomic_state(acc_tile_state[tile_ptr]);

                                    do {
                                        tile_state = sham::unpack32(atomic_state.load());
                                    } while (tile_state.x() == STATE_X);

                                    accum += tile_state.y();

                                    tile_ptr--;
                                }
                            }

                            tile_atomic.store(sham::pack32(STATE_P, accum + local_group_sum));

                            local_sum[0] = accum;
                        }

                        // sync
                        id.barrier(sycl::access::fence_space::local_space);

                        // store final result
                        if (global_id < len) {
                            in_out_ptr[global_id] = local_scan + local_sum[0];
                        }
                    });
            });
        buf1.complete_event_state(e);
        tile_state.complete_event_state(e);
    }

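    // v6: same single-pass scan, but with a parallelized look-back. The first
    // `thread_counts` work-items of a group each poll one predecessor tile per
    // iteration and the partial results are combined with group reductions, instead
    // of a single work-item walking the tiles one by one.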
    template<class T, u32 group_size, u32 thread_counts>
    class KernelExclusiveSumAtomicSyncDecoupled_v6;

    template<class T, u32 group_size, u32 thread_counts>
    sycl::buffer<T> exclusive_sum_atomic_decoupled_v6(
        sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {

        u32 group_cnt = shambase::group_count(len, group_size);

        group_cnt = group_cnt + (group_cnt % 4);
        u32 corrected_len = group_cnt * group_size;

        // prepare the return buffer by shifting values for the exclusive sum
        sycl::buffer<T> ret_buf(corrected_len);

        // logger::raw_ln("shifted : ");
        // shamalgs::memory::print_buf(ret_buf, len, 16,"{:4} ");

        // group aggregates
        sycl::buffer<typename ScanTile<T>::PackStorage> tile_state(group_cnt);

        constexpr T STATE_X = 0;
        constexpr T STATE_A = 1;
        constexpr T STATE_P = 2;

        shamalgs::memory::buf_fill_discard(q, tile_state, sham::pack32(STATE_X, T(0)));

        q.submit([&, group_cnt, len](sycl::handler &cgh) {
            sycl::accessor acc_in{buf1, cgh, sycl::read_only};
            sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};
            sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};

            sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
            sycl::local_accessor<T, 1> local_sum{1, cgh};

            // sycl::stream dump (4096, 1024, cgh);

            using atomic_ref_T = sycl::atomic_ref<
                u64,
                sycl::memory_order_relaxed,
                sycl::memory_scope_work_group,
                sycl::access::address_space::global_space>;

            cgh.parallel_for<
                KernelExclusiveSumAtomicSyncDecoupled_v6<T, group_size, thread_counts>>(
                sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                    u32 local_id = id.get_local_id(0);
                    u32 group_tile_id = id.get_group_linear_id();
                    u32 global_id = group_tile_id * group_size + local_id;

                    auto local_group = id.get_group();

                    // load from global buffer
                    T local_val = (global_id > 0 && global_id < len) ? acc_in[global_id - 1] : 0;

                    // local scan in the group
                    // the local sum will be in local id `group_size - 1`
                    T local_scan
                        = sycl::inclusive_scan_over_group(local_group, local_val, sycl::plus<T>{});

                    if (local_id == group_size - 1) {
                        local_scan_buf[0] = local_scan;
                    }

                    // sync group
                    id.barrier(sycl::access::fence_space::local_space);

                    // parallelized lookback
                    static_assert(thread_counts <= group_size, "impossible");

                    T local_group_sum = local_scan_buf[0];
                    T accum = 0;

                    T sum_state;
                    u32 last_p_index;

                    if (group_tile_id != 0) {
                        if (local_id == 0) {
                            atomic_ref_T(acc_tile_state[group_tile_id])
                                .store(sham::pack32(STATE_A, local_group_sum));
                        }

                        sycl::vec<T, 2> tile_state;
                        u32 group_tile_ptr = group_tile_id - 1;

                        bool continue_loop = true;

                        do {

                            if ((local_id < thread_counts) && (group_tile_ptr >= local_id)) {
                                atomic_ref_T atomic_state(
                                    acc_tile_state[group_tile_ptr - local_id]);

                                do {
                                    tile_state = sham::unpack32(atomic_state.load());
                                } while (tile_state.x() == STATE_X);

                            } else {
                                tile_state = {STATE_A, 0};
                            }

                            // if(group_tile_id == 25) dump << "ps : " << tile_state << "\n";

                            sum_state = sycl::reduce_over_group(
                                local_group, tile_state.x(), sycl::plus<T>{});

                            // if(group_tile_id == 25) dump << "ss : " << sum_state << "\n";

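                            // every inspected tile is A (=1) or P (=2) and the remaining
                            // work-items contribute A, so the sum equals group_size when
                            // only aggregates were seen and exceeds it as soon as at least
                            // one tile exposes a full prefix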
                            if (sum_state > group_size) {
                                // there is a P

                                continue_loop = false;

                                last_p_index = sycl::reduce_over_group(
                                    local_group,
                                    (tile_state.x() == STATE_P) ? (local_id) : (group_size),
                                    sycl::minimum<T>{});

                                // if(group_tile_id == 25) dump << "lp : " << last_p_index << "\n";

                                tile_state.y() = (local_id <= last_p_index) ? tile_state.y() : 0;

                                // if(group_tile_id == 25) dump << "ts : " << tile_state << "\n";

                            } else {
                                // there are only A's
                                continue_loop = (group_tile_ptr >= thread_counts);
                                group_tile_ptr -= thread_counts;
                            }

                            accum += sycl::reduce_over_group(
                                local_group, tile_state.y(), sycl::plus<T>{});

                            // if(group_tile_id == 25) dump << "as : " << accum << "\n";

                        } while (continue_loop);
                    }

                    if (local_id == 0) {
                        atomic_ref_T(acc_tile_state[group_tile_id])
                            .store(sham::pack32(STATE_P, accum + local_group_sum));
                    }

                    // store final result
                    if (global_id < len) {
                        acc_out[global_id] = accum + local_scan;
                    }
                });
        });

        return ret_buf;
    }

} // namespace shamalgs::numeric::details