27template<
class T, u32 work_group_size>
30template<
class T, u32 work_group_size>
33template<
class T, u32 work_group_size>
36namespace shamalgs::reduction::details {
38#ifdef SYCL2020_FEATURE_GROUP_REDUCTION
/// Collection of static reductions (sum / min / max) over a range of a SYCL buffer.
///
/// Template parameters: `T` is the element type of the buffer; `work_group_size`
/// is the factor the definitions use to turn (group id, local id) into a global
/// index and to scale the slice size between passes.
///
/// Only compiled when SYCL2020_FEATURE_GROUP_REDUCTION is defined.
/// NOTE(review): the closing `};` of this struct is not visible in this listing.
40 template<
class T, u32 work_group_size>
41 struct GroupReduction {
// Sum variant; the definition works on `end_id - start_id` elements of buf1 starting at start_id.
43 static T sum(sycl::queue &q, sycl::buffer<T> &buf1,
u32 start_id,
u32 end_id);
// Min variant; same parameters as sum.
45 static T min(sycl::queue &q, sycl::buffer<T> &buf1,
u32 start_id,
u32 end_id);
// Max variant; same parameters as sum.
47 static T max(sycl::queue &q, sycl::buffer<T> &buf1,
u32 start_id,
u32 end_id);
/// Sum reduction over `end_id - start_id` elements of buf1, beginning at start_id.
///
/// Visible steps: the sub-range is handed to write_with_offset_into together with
/// a scratch buffer (buf_int); kernels submitted to q then combine values inside
/// buf_int while `len / cur_slice_sz > work_group_size * 8`; the remaining values
/// are gathered into `recov` and read back through a host accessor.
///
/// NOTE(review): several statements are missing from this listing (declarations
/// of cur_slice_sz, exec_range and max_id, the host-side accumulation and the
/// return statement), so the comments below only describe what is visible.
///
/// @param q        queue every command group is submitted to
/// @param buf1     input buffer; only passed to write_with_offset_into here
/// @param start_id offset of the range inside buf1
/// @param end_id   end of the range; the element count is end_id - start_id
/// @return a value of type T (the return statement is not visible here)
50 template<
class T, u32 work_group_size>
51 inline T GroupReduction<T, work_group_size>::sum(
52 sycl::queue &q, sycl::buffer<T> &buf1,
u32 start_id,
u32 end_id) {
// Unsigned subtraction: wraps around if end_id < start_id.
53 u32 len = end_id - start_id;
// Scratch buffer of len elements; the kernels below read and write it.
55 sycl::buffer<T> buf_int(len);
// presumably copies buf1[start_id .. start_id + len) into buf_int — confirm against the helper
57 shamalgs::memory::write_with_offset_into(q, buf_int, buf1, start_id, len);
// Number of values still to be combined; starts as the full range length.
60 u32 remaining_val = len;
// Keep running device passes while more than work_group_size * 8 slices remain.
61 while (len / cur_slice_sz > work_group_size * 8) {
65 q.submit([&](sycl::handler &cgh) {
66 sycl::accessor global_mem{buf_int, cgh, sycl::read_write};
// Distance between consecutive inputs of this pass, and between consecutive outputs.
68 u32 slice_read_size = cur_slice_sz;
69 u32 slice_write_size = cur_slice_sz * work_group_size;
73 exec_range, [=](sycl::nd_item<1> item) {
74 u64 lid = item.get_local_id(0);
75 u64 group_tile_id = item.get_group_linear_id();
76 u64 gid = group_tile_id * work_group_size + lid;
// Item gid reads element gid * slice_read_size. iwrite is the same index as
// the iread of the group's item with lid == 0.
78 u64 iread = gid * slice_read_size;
79 u64 iwrite = group_tile_id * slice_write_size;
// Items whose read index is not below max_id contribute get_zero() instead.
// NOTE(review): the definition of max_id is not visible in this listing.
81 T val_read = (iread < max_id) ? global_mem[iread]
82 :
shambase::VectorProperties<T>::get_zero();
// Combine the val_read of the whole work-group via sham::sum_over_group.
84 T local_red = sham::sum_over_group(item.get_group(), val_read);
// NOTE(review): any guard around this store (e.g. on lid) is not visible here.
88 global_mem[iwrite] = local_red;
// The next pass reads values that are work_group_size times further apart.
93 cur_slice_sz *= work_group_size;
// One value per work-group of exec_range is left after the pass.
94 remaining_val = exec_range.get_group_range().size();
// Buffer receiving the remaining_val values still to be combined.
97 sycl::buffer<T> recov{remaining_val};
// remaining_val is captured by copy; everything else by reference.
100 q.submit([&, remaining_val](sycl::handler &cgh) {
101 sycl::accessor compute_buf{buf_int, cgh, sycl::read_only};
102 sycl::accessor result{recov, cgh, sycl::write_only, sycl::no_init};
104 u32 slice_read_size = cur_slice_sz;
106 cgh.parallel_for(exec_range, [=](sycl::nd_item<1> item) {
107 u64 lid = item.get_local_id(0);
108 u64 group_tile_id = item.get_group_linear_id();
109 u64 gid = group_tile_id * work_group_size + lid;
111 u64 iread = gid * slice_read_size;
// NOTE(review): body of this branch is not visible; presumably an early return
// so that items past remaining_val do not touch result.
113 if (gid >= remaining_val) {
// Gather: element gid * slice_read_size of buf_int becomes entry gid of recov.
117 result[gid] = compute_buf[iread];
// Read the gathered values on the host.
123 sycl::host_accessor acc{recov, sycl::read_only};
// NOTE(review): the loop body is not visible here; presumably accumulates acc[i].
124 for (
u64 i = 0; i < remaining_val; i++) {
/// Min reduction over `end_id - start_id` elements of buf1, beginning at start_id.
///
/// Visible steps: the sub-range is handed to write_with_offset_into together with
/// a scratch buffer (buf_int); kernels submitted to q then combine values inside
/// buf_int while `len / cur_slice_sz > work_group_size * 8`; the remaining values
/// are gathered into `recov` and folded on the host with sham::min.
///
/// NOTE(review): several statements are missing from this listing (declarations
/// of exec_range, max_id and ret, and the return statement), so the comments
/// below only describe what is visible.
///
/// @param q        queue every command group is submitted to
/// @param buf1     input buffer; only passed to write_with_offset_into here
/// @param start_id offset of the range inside buf1
/// @param end_id   end of the range; the element count is end_id - start_id
/// @return a value of type T (the return statement is not visible here)
132 template<
class T, u32 work_group_size>
133 inline T GroupReduction<T, work_group_size>::min(
134 sycl::queue &q, sycl::buffer<T> &buf1,
u32 start_id,
u32 end_id) {
// Unsigned subtraction: wraps around if end_id < start_id.
135 u32 len = end_id - start_id;
// Scratch buffer of len elements; the kernels below read and write it.
137 sycl::buffer<T> buf_int(len);
// presumably copies buf1[start_id .. start_id + len) into buf_int — confirm against the helper
139 shamalgs::memory::write_with_offset_into(q, buf_int, buf1, start_id, len);
// Distance between the values still to be combined; 1 means every element.
141 u32 cur_slice_sz = 1;
// Number of values still to be combined; starts as the full range length.
142 u32 remaining_val = len;
// Keep running device passes while more than work_group_size * 8 slices remain.
143 while (len / cur_slice_sz > work_group_size * 8) {
147 q.submit([&](sycl::handler &cgh) {
148 sycl::accessor global_mem{buf_int, cgh, sycl::read_write};
// Distance between consecutive inputs of this pass, and between consecutive outputs.
150 u32 slice_read_size = cur_slice_sz;
151 u32 slice_write_size = cur_slice_sz * work_group_size;
155 exec_range, [=](sycl::nd_item<1> item) {
156 u64 lid = item.get_local_id(0);
157 u64 group_tile_id = item.get_group_linear_id();
158 u64 gid = group_tile_id * work_group_size + lid;
// Item gid reads element gid * slice_read_size. iwrite is the same index as
// the iread of the group's item with lid == 0.
160 u64 iread = gid * slice_read_size;
161 u64 iwrite = group_tile_id * slice_write_size;
// Items whose read index is not below max_id contribute get_max() instead.
// NOTE(review): the definition of max_id is not visible in this listing.
163 T val_read = (iread < max_id) ? global_mem[iread]
164 :
shambase::VectorProperties<T>::get_max();
// Combine the val_read of the whole work-group via sham::min_over_group.
166 T local_red = sham::min_over_group(item.get_group(), val_read);
// NOTE(review): any guard around this store (e.g. on lid) is not visible here.
170 global_mem[iwrite] = local_red;
// The next pass reads values that are work_group_size times further apart.
175 cur_slice_sz *= work_group_size;
// One value per work-group of exec_range is left after the pass.
176 remaining_val = exec_range.get_group_range().size();
// Buffer receiving the remaining_val values still to be combined.
179 sycl::buffer<T> recov{remaining_val};
// remaining_val is captured by copy; everything else by reference.
182 q.submit([&, remaining_val](sycl::handler &cgh) {
183 sycl::accessor compute_buf{buf_int, cgh, sycl::read_only};
184 sycl::accessor result{recov, cgh, sycl::write_only, sycl::no_init};
186 u32 slice_read_size = cur_slice_sz;
188 cgh.parallel_for(exec_range, [=](sycl::nd_item<1> item) {
189 u64 lid = item.get_local_id(0);
190 u64 group_tile_id = item.get_group_linear_id();
191 u64 gid = group_tile_id * work_group_size + lid;
193 u64 iread = gid * slice_read_size;
// NOTE(review): body of this branch is not visible; presumably an early return
// so that items past remaining_val do not touch result.
195 if (gid >= remaining_val) {
// Gather: element gid * slice_read_size of buf_int becomes entry gid of recov.
199 result[gid] = compute_buf[iread];
// Read the gathered values on the host.
205 sycl::host_accessor acc{recov, sycl::read_only};
// Fold the gathered values into ret with sham::min.
// NOTE(review): the declaration and initial value of ret are not visible here.
206 for (
u64 i = 0; i < remaining_val; i++) {
207 ret = sham::min(acc[i], ret);
/// Max reduction over `end_id - start_id` elements of buf1, beginning at start_id.
///
/// Visible steps: the sub-range is handed to write_with_offset_into together with
/// a scratch buffer (buf_int); kernels submitted to q then combine values inside
/// buf_int while `len / cur_slice_sz > work_group_size * 8`; the remaining values
/// are gathered into `recov` and folded on the host with sham::max.
///
/// NOTE(review): several statements are missing from this listing (declarations
/// of exec_range, max_id and ret, and the return statement), so the comments
/// below only describe what is visible.
///
/// @param q        queue every command group is submitted to
/// @param buf1     input buffer; only passed to write_with_offset_into here
/// @param start_id offset of the range inside buf1
/// @param end_id   end of the range; the element count is end_id - start_id
/// @return a value of type T (the return statement is not visible here)
214 template<
class T, u32 work_group_size>
215 inline T GroupReduction<T, work_group_size>::max(
216 sycl::queue &q, sycl::buffer<T> &buf1,
u32 start_id,
u32 end_id) {
// Unsigned subtraction: wraps around if end_id < start_id.
217 u32 len = end_id - start_id;
// Scratch buffer of len elements; the kernels below read and write it.
219 sycl::buffer<T> buf_int(len);
// presumably copies buf1[start_id .. start_id + len) into buf_int — confirm against the helper
221 shamalgs::memory::write_with_offset_into(q, buf_int, buf1, start_id, len);
// Distance between the values still to be combined; 1 means every element.
223 u32 cur_slice_sz = 1;
// Number of values still to be combined; starts as the full range length.
224 u32 remaining_val = len;
// Keep running device passes while more than work_group_size * 8 slices remain.
225 while (len / cur_slice_sz > work_group_size * 8) {
229 q.submit([&](sycl::handler &cgh) {
230 sycl::accessor global_mem{buf_int, cgh, sycl::read_write};
// Distance between consecutive inputs of this pass, and between consecutive outputs.
232 u32 slice_read_size = cur_slice_sz;
233 u32 slice_write_size = cur_slice_sz * work_group_size;
237 exec_range, [=](sycl::nd_item<1> item) {
238 u64 lid = item.get_local_id(0);
239 u64 group_tile_id = item.get_group_linear_id();
240 u64 gid = group_tile_id * work_group_size + lid;
// Item gid reads element gid * slice_read_size. iwrite is the same index as
// the iread of the group's item with lid == 0.
242 u64 iread = gid * slice_read_size;
243 u64 iwrite = group_tile_id * slice_write_size;
// Items whose read index is not below max_id contribute get_min() instead.
// NOTE(review): the definition of max_id is not visible in this listing.
245 T val_read = (iread < max_id) ? global_mem[iread]
246 :
shambase::VectorProperties<T>::get_min();
// Combine the val_read of the whole work-group via sham::max_over_group.
248 T local_red = sham::max_over_group(item.get_group(), val_read);
// NOTE(review): any guard around this store (e.g. on lid) is not visible here.
252 global_mem[iwrite] = local_red;
// The next pass reads values that are work_group_size times further apart.
257 cur_slice_sz *= work_group_size;
// One value per work-group of exec_range is left after the pass.
258 remaining_val = exec_range.get_group_range().size();
// Buffer receiving the remaining_val values still to be combined.
261 sycl::buffer<T> recov{remaining_val};
// remaining_val is captured by copy; everything else by reference.
264 q.submit([&, remaining_val](sycl::handler &cgh) {
265 sycl::accessor compute_buf{buf_int, cgh, sycl::read_only};
266 sycl::accessor result{recov, cgh, sycl::write_only, sycl::no_init};
268 u32 slice_read_size = cur_slice_sz;
270 cgh.parallel_for(exec_range, [=](sycl::nd_item<1> item) {
271 u64 lid = item.get_local_id(0);
272 u64 group_tile_id = item.get_group_linear_id();
273 u64 gid = group_tile_id * work_group_size + lid;
275 u64 iread = gid * slice_read_size;
// NOTE(review): body of this branch is not visible; presumably an early return
// so that items past remaining_val do not touch result.
277 if (gid >= remaining_val) {
// Gather: element gid * slice_read_size of buf_int becomes entry gid of recov.
281 result[gid] = compute_buf[iread];
// Read the gathered values on the host.
287 sycl::host_accessor acc{recov, sycl::read_only};
// Fold the gathered values into ret with sham::max.
// NOTE(review): the declaration and initial value of ret are not visible here.
288 for (
u64 i = 0; i < remaining_val; i++) {
289 ret = sham::max(acc[i], ret);
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
Namespace for basic C++ utilities
sycl::nd_range< 1 > make_range(u32 length, const u32 group_size=32)
Generate a sycl nd range out of a group size and length.
Main include file for memory algorithms.