exclusiveScanAtomic.hpp
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

#pragma once

#include "shambase/integer.hpp"
#include "shamalgs/memory.hpp"
#include "shambackends/math.hpp"
#include "shambackends/sycl.hpp"

namespace shamalgs::numeric::details {

    template<class T, u32 group_size>
    class KernelExclusiveSumAtomicSync;

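    /// Exclusive sum where work-groups are serialized by a device-wide atomic
    /// counter: each group spin-waits until all preceding groups have added
    /// their tile sum to a single atomic accumulator, then fetch-adds its own
    /// tile sum to obtain its exclusive group prefix.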
    template<class T, u32 group_size>
    sycl::buffer<T> exclusive_sum_atomic2pass(sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {

        u32 group_cnt = shambase::group_count(len, group_size);
        u32 corrected_len = group_cnt * group_size;
        // prepare the return buffer by shifting values for the exclusive sum
        sycl::buffer<T> ret_buf(len);

        q.submit([&, len](sycl::handler &cgh) {
            sycl::accessor acc_in{buf1, cgh, sycl::read_only};
            sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};

            cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
                u32 thid = id.get_linear_id();
                acc_out[id] = (thid > 0) ? acc_in[thid - 1] : 0;
            });
        });

        // logger::raw_ln("shifted : ");
        // shamalgs::memory::print_buf(ret_buf, len, 16,"{:4} ");

        atomic::DynamicIdGenerator<i32, group_size> id_gen(q);

        atomic::DeviceCounter<i32> device_count(q);
        atomic::DeviceCounter<u32> global_summation(q);

        q.submit([&, group_cnt, len](sycl::handler &cgh) {
            sycl::accessor value_buffer{ret_buf, cgh, sycl::read_write};

            auto dyn_id = id_gen.get_access(cgh);
            auto device_counter = device_count.get_access(cgh);
            auto global_sum = global_summation.get_access(cgh);

            sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
            sycl::local_accessor<T, 1> local_sum{1, cgh};

            cgh.parallel_for<KernelExclusiveSumAtomicSync<T, group_size>>(
                sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                    atomic::DynamicId<i32> group_id = dyn_id.compute_id(id);

                    // load from global buffer
                    T local_val;
                    if (group_id.dyn_global_id < len) {
                        local_val = value_buffer[group_id.dyn_global_id];
                    } else {
                        local_val = 0;
                    }

                    // local scan in the group
                    // the local sum will be in local id `group_size - 1`
                    T local_scan = sycl::inclusive_scan_over_group(
                        id.get_group(), local_val, sycl::plus<T>{});

                    if (id.get_local_id(0) == group_size - 1) {
                        local_scan_buf[0] = local_scan;
                    }

                    // sync group
                    id.barrier(sycl::access::fence_space::local_space);

                    // device-wide synchronization pattern, cf. "Data Parallel C++: Mastering DPC++"
                    if (group_id.is_main_thread) {

                        // setup device counter atomic
                        sycl::atomic_ref atomic_counter
                            = device_counter.attach_atomic<sycl::memory_order_acq_rel>();
                        sycl::atomic_ref atomic_sum
                            = global_sum.attach_atomic<sycl::memory_order_relaxed>();

                        // load group sum
                        T group_sum = local_scan_buf[0];

                        // global scan using atomic counter

                        if (group_id.dyn_group_id == 0) {

                            // store local sum
                            atomic_sum += group_sum;
                            atomic_counter++;
                            local_sum[0] = 0;

                        } else {
                            // spin until every preceding group has published its sum
                            while (atomic_counter.load() != group_id.dyn_group_id) {
                            }

                            T exclusive_group_prefix_sum = atomic_sum.fetch_add(group_sum);

                            atomic_counter++;
                            local_sum[0] = exclusive_group_prefix_sum;
                        }
                    }

                    // sync
                    id.barrier(sycl::access::fence_space::local_space);

                    // store final result
                    if (group_id.dyn_global_id < len) {
                        value_buffer[group_id.dyn_global_id] = local_scan + local_sum[0];
                        // local_scan - local_val + local_sum[0];
                    }
                });
        });

        return ret_buf;
    }

    template<class T, u32 group_size>
    class KernelExclusiveSumAtomicSync_v2;

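    /// Variant of exclusive_sum_atomic2pass that chains per-group prefixes
    /// through a global `aggregates` buffer instead of a single accumulator:
    /// group i waits on the counter, reads the inclusive prefix of group i-1,
    /// adds its own tile sum, and publishes the result for the next group.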
    template<class T, u32 group_size>
    sycl::buffer<T> exclusive_sum_atomic2pass_v2(sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {

        u32 group_cnt = shambase::group_count(len, group_size);
        u32 corrected_len = group_cnt * group_size;
        // prepare the return buffer by shifting values for the exclusive sum
        sycl::buffer<T> ret_buf(len);

        q.submit([&, len](sycl::handler &cgh) {
            sycl::accessor acc_in{buf1, cgh, sycl::read_only};
            sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};

            cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
                u32 thid = id.get_linear_id();
                acc_out[id] = (thid > 0) ? acc_in[thid - 1] : 0;
            });
        });

        // logger::raw_ln("shifted : ");
        // shamalgs::memory::print_buf(ret_buf, len, 16,"{:4} ");

        // group aggregates
        sycl::buffer<T> aggregates(group_cnt);

        atomic::DynamicIdGenerator<i32, group_size> id_gen(q);

        atomic::DeviceCounter<i32> device_count(q);

        q.submit([&, group_cnt, len](sycl::handler &cgh) {
            sycl::accessor value_buffer{ret_buf, cgh, sycl::read_write};

            auto dyn_id = id_gen.get_access(cgh);
            auto device_counter = device_count.get_access(cgh);

            sycl::accessor acc_gsum{aggregates, cgh, sycl::read_write};

            sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
            sycl::local_accessor<T, 1> local_sum{1, cgh};

            cgh.parallel_for<KernelExclusiveSumAtomicSync_v2<T, group_size>>(
                sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                    atomic::DynamicId<i32> group_id = dyn_id.compute_id(id);

                    // load from global buffer
                    T local_val;
                    if (group_id.dyn_global_id < len) {
                        local_val = value_buffer[group_id.dyn_global_id];
                    } else {
                        local_val = 0;
                    }

                    // local scan in the group
                    // the local sum will be in local id `group_size - 1`
                    T local_scan = sycl::inclusive_scan_over_group(
                        id.get_group(), local_val, sycl::plus<T>{});

                    if (id.get_local_id(0) == group_size - 1) {
                        local_scan_buf[0] = local_scan;
                    }

                    // sync group
                    id.barrier(sycl::access::fence_space::local_space);

                    // device-wide synchronization pattern, cf. "Data Parallel C++: Mastering DPC++"
                    if (group_id.is_main_thread) {

                        // setup device counter atomic
                        sycl::atomic_ref atomic_counter
                            = device_counter.attach_atomic<sycl::memory_order_acq_rel>();

                        // load group sum
                        T group_sum = local_scan_buf[0];

                        // global scan using atomic counter

                        using atomic_ref_T = sycl::atomic_ref<
                            T,
                            sycl::memory_order_relaxed,
                            sycl::memory_scope_device,
                            sycl::access::address_space::global_space>;

                        if (group_id.dyn_group_id == 0) {

                            // store local sum
                            atomic_ref_T(acc_gsum[0]).store(group_sum);

                            atomic_counter++;
                            local_sum[0] = 0;

                        } else {
                            // spin until every preceding group has published its prefix
                            while (atomic_counter.load() != group_id.dyn_group_id) {
                            }

                            T exclusive_group_prefix_sum
                                = atomic_ref_T(acc_gsum[group_id.dyn_group_id - 1]).load();

                            atomic_ref_T(acc_gsum[group_id.dyn_group_id])
                                .store(exclusive_group_prefix_sum + group_sum);

                            atomic_counter++;
                            local_sum[0] = exclusive_group_prefix_sum;
                        }
                    }

                    // sync
                    id.barrier(sycl::access::fence_space::local_space);

                    // store final result
                    if (group_id.dyn_global_id < len) {
                        value_buffer[group_id.dyn_global_id] = local_scan + local_sum[0];
                        // local_scan - local_val + local_sum[0];
                    }
                });
        });

        return ret_buf;
    }

    template<class T, u32 group_size>
    class KernelExclusiveSumAtomicSyncDecoupled;

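    /// Single-kernel exclusive sum using a decoupled look-back scheme (in the
    /// spirit of Merrill & Garland's single-pass scan): each tile publishes
    /// its aggregate (state A), then walks backward over preceding tiles,
    /// summing aggregates until it reaches a tile whose inclusive prefix is
    /// published (state P). Tile state, aggregate and inclusive prefix live
    /// in three separate buffers here.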
    template<class T, u32 group_size>
    sycl::buffer<T> exclusive_sum_atomic_decoupled(sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {

        u32 group_cnt = shambase::group_count(len, group_size);
        u32 corrected_len = group_cnt * group_size;

        // prepare the return buffer by shifting values for the exclusive sum
        sycl::buffer<T> ret_buf(len);

        q.submit([&, len](sycl::handler &cgh) {
            sycl::accessor acc_in{buf1, cgh, sycl::read_only};
            sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};

            cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
                u32 thid = id.get_linear_id();
                acc_out[id] = (thid > 0) ? acc_in[thid - 1] : 0;
            });
        });

        // logger::raw_ln("shifted : ");
        // shamalgs::memory::print_buf(ret_buf, len, 16,"{:4} ");

        // per-tile states and aggregates
        sycl::buffer<i32> tile_state(group_cnt);
        sycl::buffer<T> tile_aggregates(group_cnt);
        sycl::buffer<T> tile_incl_prefix(group_cnt);

        constexpr i32 STATE_X = 0; // tile not yet processed
        constexpr i32 STATE_A = 1; // tile aggregate available
        constexpr i32 STATE_P = 2; // tile inclusive prefix available

        shamalgs::memory::buf_fill_discard(q, tile_state, STATE_X);
        shamalgs::memory::buf_fill_discard(q, tile_aggregates, T(0));
        shamalgs::memory::buf_fill_discard(q, tile_incl_prefix, T(0));

        atomic::DynamicIdGenerator<i32, group_size> id_gen(q);

        q.submit([&, group_cnt, len](sycl::handler &cgh) {
            sycl::accessor acc_value{ret_buf, cgh, sycl::read_write};
            sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};
            sycl::accessor acc_tile_aggregates{tile_aggregates, cgh, sycl::read_write};
            sycl::accessor acc_tile_incl_prefix{tile_incl_prefix, cgh, sycl::read_write};

            auto dyn_id = id_gen.get_access(cgh);

            sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
            sycl::local_accessor<T, 1> local_sum{1, cgh};

            using atomic_ref_state = sycl::atomic_ref<
                i32,
                sycl::memory_order_relaxed,
                sycl::memory_scope_device,
                sycl::access::address_space::global_space>;

            using atomic_ref_T = sycl::atomic_ref<
                T,
                sycl::memory_order_relaxed,
                sycl::memory_scope_device,
                sycl::access::address_space::global_space>;

            cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled<T, group_size>>(
                sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                    atomic::DynamicId<i32> group_id = dyn_id.compute_id(id);

                    // load from global buffer
                    T local_val;
                    if (group_id.dyn_global_id < len) {
                        local_val = acc_value[group_id.dyn_global_id];
                    } else {
                        local_val = 0;
                    }

                    // local scan in the group
                    // the local sum will be in local id `group_size - 1`
                    T local_scan = sycl::inclusive_scan_over_group(
                        id.get_group(), local_val, sycl::plus<T>{});

                    if (id.get_local_id(0) == group_size - 1) {
                        local_scan_buf[0] = local_scan;
                    }

                    // sync group
                    id.barrier(sycl::access::fence_space::local_space);

                    // device-wide synchronization pattern, cf. "Data Parallel C++: Mastering DPC++"
                    if (group_id.is_main_thread) {

                        // load group sum
                        T local_group_sum = local_scan_buf[0];
                        T accum = 0;

                        // decoupled look-back over preceding tiles

                        if (group_id.dyn_group_id != 0) {

                            atomic_ref_T(acc_tile_aggregates[group_id.dyn_group_id])
                                .store(local_group_sum);
                            atomic_ref_state(acc_tile_state[group_id.dyn_group_id]).store(STATE_A);

                            u32 tile_ptr = group_id.dyn_group_id - 1;

                            while (true) {
                                i32 tstate = atomic_ref_state(acc_tile_state[tile_ptr]).load();

                                if (tstate == STATE_X) {
                                    continue; // tile not ready yet, retry
                                }

                                if (tstate == STATE_A) {
                                    accum += atomic_ref_T(acc_tile_aggregates[tile_ptr]).load();
                                }

                                if (tstate == STATE_P) {
                                    accum += atomic_ref_T(acc_tile_incl_prefix[tile_ptr]).load();
                                    break;
                                }

                                tile_ptr--;
                            }
                        }

                        atomic_ref_T(acc_tile_incl_prefix[group_id.dyn_group_id])
                            .store(accum + local_group_sum);
                        atomic_ref_state(acc_tile_state[group_id.dyn_group_id]).store(STATE_P);

                        local_sum[0] = accum;
                    }

                    // sync
                    id.barrier(sycl::access::fence_space::local_space);

                    // store final result
                    if (group_id.dyn_global_id < len) {
                        acc_value[group_id.dyn_global_id] = local_scan + local_sum[0];
                        // local_scan - local_val + local_sum[0];
                    }
                });
        });

        return ret_buf;
    }

    template<class T, u32 group_size>
    class KernelExclusiveSumAtomicSyncDecoupled_v2;

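    /// Same look-back scheme as exclusive_sum_atomic_decoupled, but the tile
    /// state and its value are packed into one u64 (sham::pack32) so that a
    /// tile descriptor is published and read with a single atomic operation.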
    template<class T, u32 group_size>
    sycl::buffer<T> exclusive_sum_atomic_decoupled_v2(
        sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {

        u32 group_cnt = shambase::group_count(len, group_size);
        u32 corrected_len = group_cnt * group_size;

        // prepare the return buffer by shifting values for the exclusive sum
        sycl::buffer<T> ret_buf(len);

        q.submit([&, len](sycl::handler &cgh) {
            sycl::accessor acc_in{buf1, cgh, sycl::read_only};
            sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};

            cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
                u32 thid = id.get_linear_id();
                acc_out[id] = (thid > 0) ? acc_in[thid - 1] : 0;
            });
        });

        // logger::raw_ln("shifted : ");
        // shamalgs::memory::print_buf(ret_buf, len, 16,"{:4} ");

        // per-tile descriptors: (state, value) packed in a single u64
        sycl::buffer<u64> tile_state(group_cnt);

        constexpr T STATE_X = 0; // tile not yet processed
        constexpr T STATE_A = 1; // tile aggregate available
        constexpr T STATE_P = 2; // tile inclusive prefix available

        shamalgs::memory::buf_fill_discard(q, tile_state, sham::pack32(STATE_X, T(0)));

        atomic::DynamicIdGenerator<i32, group_size> id_gen(q);

        q.submit([&, group_cnt, len](sycl::handler &cgh) {
            sycl::accessor acc_value{ret_buf, cgh, sycl::read_write};
            sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};

            auto dyn_id = id_gen.get_access(cgh);

            sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
            sycl::local_accessor<T, 1> local_sum{1, cgh};

            using atomic_ref_T = sycl::atomic_ref<
                u64,
                sycl::memory_order_relaxed,
                sycl::memory_scope_device,
                sycl::access::address_space::global_space>;

            cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled_v2<T, group_size>>(
                sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                    atomic::DynamicId<i32> group_id = dyn_id.compute_id(id);

                    // load from global buffer
                    T local_val;
                    if (group_id.dyn_global_id < len) {
                        local_val = acc_value[group_id.dyn_global_id];
                    } else {
                        local_val = 0;
                    }

                    // local scan in the group
                    // the local sum will be in local id `group_size - 1`
                    T local_scan = sycl::inclusive_scan_over_group(
                        id.get_group(), local_val, sycl::plus<T>{});

                    if (id.get_local_id(0) == group_size - 1) {
                        local_scan_buf[0] = local_scan;
                    }

                    // sync group
                    id.barrier(sycl::access::fence_space::local_space);

                    // store / load a packed (state, value) tile descriptor in one atomic op
                    auto store = [=](u32 id, T state, T val) {
                        atomic_ref_T(acc_tile_state[id]).store(sham::pack32(state, val));
                    };

                    auto load = [=](u32 id) -> sycl::vec<T, 2> {
                        return sham::unpack32(atomic_ref_T(acc_tile_state[id]).load());
                    };

                    // device-wide synchronization pattern, cf. "Data Parallel C++: Mastering DPC++"
                    if (group_id.is_main_thread) {

                        // load group sum
                        T local_group_sum = local_scan_buf[0];
                        T accum = 0;
                        u32 tile_ptr = group_id.dyn_group_id - 1;

                        // decoupled look-back over preceding tiles

                        if (group_id.dyn_group_id != 0) {

                            store(group_id.dyn_group_id, STATE_A, local_group_sum);

                            while (true) {

                                sycl::vec<T, 2> state = load(tile_ptr);

                                if (state.x() == STATE_X) {
                                    continue; // tile not ready yet, retry
                                }

                                if (state.x() == STATE_A) {
                                    accum += state.y();
                                }

                                if (state.x() == STATE_P) {
                                    accum += state.y();
                                    break;
                                }

                                tile_ptr--;
                            }
                        }

                        store(group_id.dyn_group_id, STATE_P, accum + local_group_sum);

                        local_sum[0] = accum;
                    }

                    // sync
                    id.barrier(sycl::access::fence_space::local_space);

                    // store final result
                    if (group_id.dyn_global_id < len) {
                        acc_value[group_id.dyn_global_id] = local_scan + local_sum[0];
                        // local_scan - local_val + local_sum[0];
                    }
                });
        });

        return ret_buf;
    }

    template<class T, u32 group_size>
    class KernelExclusiveSumAtomicSyncDecoupled_v3;

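    /// Identical to exclusive_sum_atomic_decoupled_v2 except that the atomic
    /// scope on the tile descriptors is relaxed from memory_scope_device to
    /// memory_scope_work_group.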
    template<class T, u32 group_size>
    sycl::buffer<T> exclusive_sum_atomic_decoupled_v3(
        sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {

        u32 group_cnt = shambase::group_count(len, group_size);
        u32 corrected_len = group_cnt * group_size;

        // prepare the return buffer by shifting values for the exclusive sum
        sycl::buffer<T> ret_buf(len);

        q.submit([&, len](sycl::handler &cgh) {
            sycl::accessor acc_in{buf1, cgh, sycl::read_only};
            sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};

            cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
                u32 thid = id.get_linear_id();
                acc_out[id] = (thid > 0) ? acc_in[thid - 1] : 0;
            });
        });

        // logger::raw_ln("shifted : ");
        // shamalgs::memory::print_buf(ret_buf, len, 16,"{:4} ");

        // per-tile descriptors: (state, value) packed in a single u64
        sycl::buffer<u64> tile_state(group_cnt);

        constexpr T STATE_X = 0; // tile not yet processed
        constexpr T STATE_A = 1; // tile aggregate available
        constexpr T STATE_P = 2; // tile inclusive prefix available

        shamalgs::memory::buf_fill_discard(q, tile_state, sham::pack32(STATE_X, T(0)));

        atomic::DynamicIdGenerator<i32, group_size> id_gen(q);

        q.submit([&, group_cnt, len](sycl::handler &cgh) {
            sycl::accessor acc_value{ret_buf, cgh, sycl::read_write};
            sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};

            auto dyn_id = id_gen.get_access(cgh);

            sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
            sycl::local_accessor<T, 1> local_sum{1, cgh};

            using atomic_ref_T = sycl::atomic_ref<
                u64,
                sycl::memory_order_relaxed,
                sycl::memory_scope_work_group,
                sycl::access::address_space::global_space>;

            cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled_v3<T, group_size>>(
                sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                    atomic::DynamicId<i32> group_id = dyn_id.compute_id(id);

                    // load from global buffer
                    T local_val;
                    if (group_id.dyn_global_id < len) {
                        local_val = acc_value[group_id.dyn_global_id];
                    } else {
                        local_val = 0;
                    }

                    // local scan in the group
                    // the local sum will be in local id `group_size - 1`
                    T local_scan = sycl::inclusive_scan_over_group(
                        id.get_group(), local_val, sycl::plus<T>{});

                    if (id.get_local_id(0) == group_size - 1) {
                        local_scan_buf[0] = local_scan;
                    }

                    // sync group
                    id.barrier(sycl::access::fence_space::local_space);

                    // store / load a packed (state, value) tile descriptor in one atomic op
                    auto store = [=](u32 id, T state, T val) {
                        atomic_ref_T(acc_tile_state[id]).store(sham::pack32(state, val));
                    };

                    auto load = [=](u32 id) -> sycl::vec<T, 2> {
                        return sham::unpack32(atomic_ref_T(acc_tile_state[id]).load());
                    };

                    // device-wide synchronization pattern, cf. "Data Parallel C++: Mastering DPC++"
                    if (group_id.is_main_thread) {

                        // load group sum
                        T local_group_sum = local_scan_buf[0];
                        T accum = 0;
                        u32 tile_ptr = group_id.dyn_group_id - 1;

                        // decoupled look-back over preceding tiles

                        if (group_id.dyn_group_id != 0) {

                            store(group_id.dyn_group_id, STATE_A, local_group_sum);

                            while (true) {

                                sycl::vec<T, 2> state = load(tile_ptr);

                                if (state.x() == STATE_X) {
                                    continue; // tile not ready yet, retry
                                }

                                if (state.x() == STATE_A) {
                                    accum += state.y();
                                }

                                if (state.x() == STATE_P) {
                                    accum += state.y();
                                    break;
                                }

                                tile_ptr--;
                            }
                        }

                        store(group_id.dyn_group_id, STATE_P, accum + local_group_sum);

                        local_sum[0] = accum;
                    }

                    // sync
                    id.barrier(sycl::access::fence_space::local_space);

                    // store final result
                    if (group_id.dyn_global_id < len) {
                        acc_value[group_id.dyn_global_id] = local_scan + local_sum[0];
                        // local_scan - local_val + local_sum[0];
                    }
                });
        });

        return ret_buf;
    }

    template<class T, u32 group_size>
    class KernelExclusiveSumAtomicSyncDecoupled_v4;

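    /// Look-back variant working on padded buffers: the group count and the
    /// scan buffer are extended to corrected_len so the kernel needs no bound
    /// checks, and the look-back spins on each tile until it leaves state X.
    /// Note that the returned buffer holds corrected_len elements; only the
    /// first len values are meaningful.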
    template<class T, u32 group_size>
    sycl::buffer<T> exclusive_sum_atomic_decoupled_v4(
        sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {

        u32 group_cnt = shambase::group_count(len, group_size);

        group_cnt = group_cnt + (group_cnt % 4);
        u32 corrected_len = group_cnt * group_size;

        // prepare the (padded) return buffer by shifting values for the exclusive sum
        sycl::buffer<T> ret_buf(corrected_len);

        q.submit([&, len](sycl::handler &cgh) {
            sycl::accessor acc_in{buf1, cgh, sycl::read_only};
            sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};

            cgh.parallel_for(sycl::range<1>{corrected_len}, [=](sycl::item<1> id) {
                u32 thid = id.get_linear_id();
                acc_out[id] = (thid > 0 && thid < len) ? acc_in[thid - 1] : 0;
            });
        });

        // logger::raw_ln("shifted : ");
        // shamalgs::memory::print_buf(ret_buf, len, 16,"{:4} ");

        // per-tile descriptors: (state, value) packed in a single u64
        sycl::buffer<u64> tile_state(group_cnt);

        constexpr T STATE_X = 0; // tile not yet processed
        constexpr T STATE_A = 1; // tile aggregate available
        constexpr T STATE_P = 2; // tile inclusive prefix available

        shamalgs::memory::buf_fill_discard(q, tile_state, sham::pack32(STATE_X, T(0)));

        atomic::DynamicIdGenerator<i32, group_size> id_gen(q);

        q.submit([&, group_cnt, len](sycl::handler &cgh) {
            sycl::accessor acc_value{ret_buf, cgh, sycl::read_write};
            sycl::accessor acc_tile_state{tile_state, cgh, sycl::read_write};

            auto dyn_id = id_gen.get_access(cgh);

            sycl::local_accessor<T, 1> local_scan_buf{1, cgh};
            sycl::local_accessor<T, 1> local_sum{1, cgh};

            using atomic_ref_T = sycl::atomic_ref<
                u64,
                sycl::memory_order_relaxed,
                sycl::memory_scope_work_group,
                sycl::access::address_space::global_space>;

            cgh.parallel_for<KernelExclusiveSumAtomicSyncDecoupled_v4<T, group_size>>(
                sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                    atomic::DynamicId<i32> group_id = dyn_id.compute_id(id);

                    // load from global buffer (padded, no bound check needed)
                    T local_val = acc_value[group_id.dyn_global_id];

                    // local scan in the group
                    // the local sum will be in local id `group_size - 1`
                    T local_scan = sycl::inclusive_scan_over_group(
                        id.get_group(), local_val, sycl::plus<T>{});

                    if (id.get_local_id(0) == group_size - 1) {
                        local_scan_buf[0] = local_scan;
                    }

                    // sync group
                    id.barrier(sycl::access::fence_space::local_space);

                    // store / load a packed (state, value) tile descriptor in one atomic op
                    auto store = [=](u32 id, T state, T val) {
                        atomic_ref_T(acc_tile_state[id]).store(sham::pack32(state, val));
                    };

                    auto load = [=](u32 id) -> sycl::vec<T, 2> {
                        return sham::unpack32(atomic_ref_T(acc_tile_state[id]).load());
                    };

                    // device-wide synchronization pattern, cf. "Data Parallel C++: Mastering DPC++"
                    if (group_id.is_main_thread) {

                        // load group sum
                        T local_group_sum = local_scan_buf[0];
                        T accum = 0;
                        u32 tile_ptr = group_id.dyn_group_id - 1;
                        sycl::vec<T, 2> tile_state = {STATE_X, 0};

                        // decoupled look-back over preceding tiles

                        if (group_id.dyn_group_id != 0) {

                            store(group_id.dyn_group_id, STATE_A, local_group_sum);

                            while (tile_state.x() != STATE_P) {

                                atomic_ref_T atomic_state(acc_tile_state[tile_ptr]);

                                // spin until the tile leaves state X
                                do {
                                    tile_state = sham::unpack32(atomic_state.load());
                                } while (tile_state.x() == STATE_X);

                                accum += tile_state.y();

                                tile_ptr--;
                            }
                        }

                        store(group_id.dyn_group_id, STATE_P, accum + local_group_sum);

                        local_sum[0] = accum;
                    }

                    // sync
                    id.barrier(sycl::access::fence_space::local_space);

                    // store final result
                    acc_value[group_id.dyn_global_id] = local_scan + local_sum[0];
                });
        });

        return ret_buf;
    }

    template<class T, u32 group_size>
    class KernelExclusivesum_sycl_jointalg;

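    /// Reference variant built on sycl::joint_inclusive_scan: every launched
    /// work-group jointly scans the entire shifted range into ret_buf2.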
    template<class T, u32 group_size>
    sycl::buffer<T> exclusive_sum_sycl_jointalg(sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {

        u32 group_cnt = shambase::group_count(len, group_size);

        group_cnt = group_cnt + (group_cnt % 4);
        u32 corrected_len = group_cnt * group_size;

        // prepare the (padded) return buffer by shifting values for the exclusive sum
        sycl::buffer<T> ret_buf(corrected_len);
        sycl::buffer<T> ret_buf2(corrected_len);

        q.submit([&, len](sycl::handler &cgh) {
            sycl::accessor acc_in{buf1, cgh, sycl::read_only};
            sycl::accessor acc_out{ret_buf, cgh, sycl::write_only, sycl::no_init};

            cgh.parallel_for(sycl::range<1>{corrected_len}, [=](sycl::item<1> id) {
                u32 thid = id.get_linear_id();
                acc_out[id] = (thid > 0 && thid < len) ? acc_in[thid - 1] : 0;
            });
        });

        // logger::raw_ln("shifted : ");
        // shamalgs::memory::print_buf(ret_buf, len, 16,"{:4} ");

        // tile descriptors (unused in this variant, leftover from the decoupled versions)
        sycl::buffer<u64> tile_state(group_cnt);

        constexpr T STATE_X = 0;
        constexpr T STATE_A = 1;
        constexpr T STATE_P = 2;

        shamalgs::memory::buf_fill_discard(q, tile_state, sham::pack32(STATE_X, T(0)));

        q.submit([&, group_cnt, len](sycl::handler &cgh) {
            sycl::accessor acc_in{ret_buf, cgh, sycl::read_write};
            sycl::accessor acc_out{ret_buf2, cgh, sycl::read_write};

            cgh.parallel_for<KernelExclusivesum_sycl_jointalg<T, group_size>>(
                sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
                    T *first = &(acc_in[0]);
                    T *last = first + acc_in.size();

                    T *first_out = &(acc_out[0]);

                    // every work-group jointly scans the full shifted range
                    sycl::joint_inclusive_scan(
                        id.get_group(), first, last, first_out, sycl::plus<T>{});
                });
        });

        return ret_buf2;
    }

} // namespace shamalgs::numeric::details
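
// Minimal usage sketch (illustrative only; the queue `q`, the input buffer `in`
// and the group size of 64 are assumptions, and callers would normally go
// through the public shamalgs::numeric front-end rather than these `details`
// entry points):
//
//     sycl::queue q;
//     u32 len = 1024;
//     sycl::buffer<u32> in(len); // ... fill with input values ...
//
//     sycl::buffer<u32> result
//         = shamalgs::numeric::details::exclusive_sum_atomic_decoupled<u32, 64>(q, in, len);
//     // result[i] == in[0] + ... + in[i-1], with result[0] == 0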