Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
streamCompactExclScan.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
18#include "shambase/integer.hpp"
19#include "shambase/string.hpp"
21#include "shamalgs/memory.hpp"
22#include "shamalgs/numeric.hpp"
24#include "shamcomm/logs.hpp"
25
26class StreamCompactionAlg;
27
28namespace shamalgs::numeric::details {
29
30 std::tuple<std::optional<sycl::buffer<u32>>, u32> stream_compact_excl_scan(
31 sycl::queue &q, sycl::buffer<u32> &buf_flags, u32 len) {
32
33 if (len < 2) {
34 return stream_compact_fallback(q, buf_flags, len);
35 }
36
37 // perform the exclusive sum of the buf flag
38 sycl::buffer<u32> excl_sum = scan_exclusive(q, buf_flags, len);
39
40 // recover the end value of the sum to know the new size
41 u32 new_len = memory::extract_element(q, excl_sum, len - 1);
42
43 u32 end_flag = memory::extract_element(q, buf_flags, len - 1);
44
45 if (end_flag) {
46 new_len++;
47 }
48
49 shamlog_debug_sycl_ln("StreamCompact", "number of element : ", new_len);
50
51 if (new_len == 0) {
52 return {{}, 0};
53 }
54
55 constexpr u32 group_size = 256;
56 u32 max_len = len;
57 u32 group_cnt = shambase::group_count(len, group_size);
58 group_cnt = group_cnt + (group_cnt % 4);
59 u32 corrected_len = group_cnt * group_size;
60
61 // create the index buffer that we will return
62 sycl::buffer<u32> index_map{new_len};
63
64 q.submit([&, max_len](sycl::handler &cgh) {
65 sycl::accessor sum_vals{excl_sum, cgh, sycl::read_only};
66 sycl::accessor new_idx{index_map, cgh, sycl::write_only, sycl::no_init};
67
68 u32 last_idx = len - 1;
69 u32 last_flag = end_flag;
70
71 cgh.parallel_for<StreamCompactionAlg>(
72
73 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
74 u32 local_id = id.get_local_id(0);
75 u32 group_tile_id = id.get_group_linear_id();
76 u32 idx = group_tile_id * group_size + local_id;
77
78 if (idx >= max_len)
79 return;
80
81 u32 current_val = sum_vals[idx];
82
83 bool _if1 = (idx < last_idx);
84 bool should_write
85 = (_if1) ? (current_val < sum_vals[idx + 1]) : (bool(last_flag));
86
87 if (should_write) {
88 new_idx[current_val] = idx;
89 }
90 });
91 });
92
93 return {std::move(index_map), new_len};
94 };
95
97 const sham::DeviceScheduler_ptr &sched, sham::DeviceBuffer<u32> &buf_flags, u32 len) {
98
99 if (len < 2) {
100 return stream_compact_fallback(sched, buf_flags, len);
101 }
102
103 // perform the exclusive sum of the buf flag
104 sham::DeviceBuffer<u32> excl_sum = scan_exclusive(sched, buf_flags, len);
105
106 // recover the end value of the sum to know the new size
107 u32 new_len = excl_sum.get_val_at_idx(len - 1);
108
109 u32 end_flag = buf_flags.get_val_at_idx(len - 1);
110
111 if (end_flag) {
112 new_len++;
113 }
114
115 // create the index buffer that we will return
116 sham::DeviceBuffer<u32> index_map{new_len, sched};
117
118 if (new_len > 0) {
119 // logger::raw_ln(
120 // shambase::format("len = {}, new_len = {}, end_flag = {}", len, new_len,
121 // end_flag));
123 sched->get_queue(),
124 sham::MultiRef{excl_sum},
125 sham::MultiRef{index_map},
126 len,
127 [last_idx = len - 1,
128 last_flag = end_flag](u32 idx, const u32 *sum_vals, u32 *new_idx) {
129 u32 current_val = sum_vals[idx];
130
131 bool _if1 = (idx < last_idx);
132 bool should_write
133 = (_if1) ? (current_val < sum_vals[idx + 1]) : (bool(last_flag));
134
135 // logger::raw_ln(shambase::format(
136 // "idx = {}, sum = {}, _if1 = {}, should_write = {}",
137 // idx,
138 // sum_vals[idx],
139 // _if1,
140 // should_write));
141
142 if (should_write) {
143 new_idx[current_val] = idx;
144 }
145 });
146 }
147
148 return index_map;
149 }
150
151} // namespace shamalgs::numeric::details
std::uint32_t u32
32 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
T get_val_at_idx(size_t idx) const
Get the value at a given index in the buffer.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
T extract_element(sycl::queue &q, sycl::buffer< T > &buf, u32 idx)
Extract a value from a buffer.
Definition memory.cpp:24
sycl::buffer< T > scan_exclusive(sycl::queue &q, sycl::buffer< T > &buf1, u32 len)
Computes the exclusive sum of elements in a SYCL buffer.
Definition numeric.cpp:35
constexpr u32 group_count(u32 len, u32 group_size)
Calculates the number of groups based on the length and group size.
Definition integer.hpp:125
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact_fallback(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm on fallback.
Main include file for memory algorithms.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact_excl_scan(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm using exclusive summation.
A class that references multiple buffers or similar objects.