Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
groupReduction.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
20#include "shamalgs/memory.hpp"
22#include "shambackends/math.hpp"
23#include "shambackends/sycl.hpp"
25#include "shambackends/vec.hpp"
26
27template<class T, u32 work_group_size>
29
30template<class T, u32 work_group_size>
32
33template<class T, u32 work_group_size>
35
36namespace shamalgs::reduction::details {
37
38#ifdef SYCL2020_FEATURE_GROUP_REDUCTION
39
40 template<class T, u32 work_group_size>
41 struct GroupReduction {
42
43 static T sum(sycl::queue &q, sycl::buffer<T> &buf1, u32 start_id, u32 end_id);
44
45 static T min(sycl::queue &q, sycl::buffer<T> &buf1, u32 start_id, u32 end_id);
46
47 static T max(sycl::queue &q, sycl::buffer<T> &buf1, u32 start_id, u32 end_id);
48 };
49
50 template<class T, u32 work_group_size>
51 inline T GroupReduction<T, work_group_size>::sum(
52 sycl::queue &q, sycl::buffer<T> &buf1, u32 start_id, u32 end_id) {
53 u32 len = end_id - start_id;
54
55 sycl::buffer<T> buf_int(len);
56
57 shamalgs::memory::write_with_offset_into(q, buf_int, buf1, start_id, len);
58
59 u32 cur_slice_sz = 1;
60 u32 remaining_val = len;
61 while (len / cur_slice_sz > work_group_size * 8) {
62
63 sycl::nd_range<1> exec_range = shambase::make_range(remaining_val, work_group_size);
64
65 q.submit([&](sycl::handler &cgh) {
66 sycl::accessor global_mem{buf_int, cgh, sycl::read_write};
67
68 u32 slice_read_size = cur_slice_sz;
69 u32 slice_write_size = cur_slice_sz * work_group_size;
70 u32 max_id = len;
71
73 exec_range, [=](sycl::nd_item<1> item) {
74 u64 lid = item.get_local_id(0);
75 u64 group_tile_id = item.get_group_linear_id();
76 u64 gid = group_tile_id * work_group_size + lid;
77
78 u64 iread = gid * slice_read_size;
79 u64 iwrite = group_tile_id * slice_write_size;
80
81 T val_read = (iread < max_id) ? global_mem[iread]
82 : shambase::VectorProperties<T>::get_zero();
83
84 T local_red = sham::sum_over_group(item.get_group(), val_read);
85
86 // can be removed if i change the index in the look back ?
87 if (lid == 0) {
88 global_mem[iwrite] = local_red;
89 }
90 });
91 });
92
93 cur_slice_sz *= work_group_size;
94 remaining_val = exec_range.get_group_range().size();
95 }
96
97 sycl::buffer<T> recov{remaining_val};
98
99 sycl::nd_range<1> exec_range = shambase::make_range(remaining_val, work_group_size);
100 q.submit([&, remaining_val](sycl::handler &cgh) {
101 sycl::accessor compute_buf{buf_int, cgh, sycl::read_only};
102 sycl::accessor result{recov, cgh, sycl::write_only, sycl::no_init};
103
104 u32 slice_read_size = cur_slice_sz;
105
106 cgh.parallel_for(exec_range, [=](sycl::nd_item<1> item) {
107 u64 lid = item.get_local_id(0);
108 u64 group_tile_id = item.get_group_linear_id();
109 u64 gid = group_tile_id * work_group_size + lid;
110
111 u64 iread = gid * slice_read_size;
112
113 if (gid >= remaining_val) {
114 return;
115 }
116
117 result[gid] = compute_buf[iread];
118 });
119 });
120
122 {
123 sycl::host_accessor acc{recov, sycl::read_only};
124 for (u64 i = 0; i < remaining_val; i++) {
125 ret += acc[i];
126 }
127 }
128
129 return ret;
130 }
131
132 template<class T, u32 work_group_size>
133 inline T GroupReduction<T, work_group_size>::min(
134 sycl::queue &q, sycl::buffer<T> &buf1, u32 start_id, u32 end_id) {
135 u32 len = end_id - start_id;
136
137 sycl::buffer<T> buf_int(len);
138
139 shamalgs::memory::write_with_offset_into(q, buf_int, buf1, start_id, len);
140
141 u32 cur_slice_sz = 1;
142 u32 remaining_val = len;
143 while (len / cur_slice_sz > work_group_size * 8) {
144
145 sycl::nd_range<1> exec_range = shambase::make_range(remaining_val, work_group_size);
146
147 q.submit([&](sycl::handler &cgh) {
148 sycl::accessor global_mem{buf_int, cgh, sycl::read_write};
149
150 u32 slice_read_size = cur_slice_sz;
151 u32 slice_write_size = cur_slice_sz * work_group_size;
152 u32 max_id = len;
153
155 exec_range, [=](sycl::nd_item<1> item) {
156 u64 lid = item.get_local_id(0);
157 u64 group_tile_id = item.get_group_linear_id();
158 u64 gid = group_tile_id * work_group_size + lid;
159
160 u64 iread = gid * slice_read_size;
161 u64 iwrite = group_tile_id * slice_write_size;
162
163 T val_read = (iread < max_id) ? global_mem[iread]
164 : shambase::VectorProperties<T>::get_max();
165
166 T local_red = sham::min_over_group(item.get_group(), val_read);
167
168 // can be removed if i change the index in the look back ?
169 if (lid == 0) {
170 global_mem[iwrite] = local_red;
171 }
172 });
173 });
174
175 cur_slice_sz *= work_group_size;
176 remaining_val = exec_range.get_group_range().size();
177 }
178
179 sycl::buffer<T> recov{remaining_val};
180
181 sycl::nd_range<1> exec_range = shambase::make_range(remaining_val, work_group_size);
182 q.submit([&, remaining_val](sycl::handler &cgh) {
183 sycl::accessor compute_buf{buf_int, cgh, sycl::read_only};
184 sycl::accessor result{recov, cgh, sycl::write_only, sycl::no_init};
185
186 u32 slice_read_size = cur_slice_sz;
187
188 cgh.parallel_for(exec_range, [=](sycl::nd_item<1> item) {
189 u64 lid = item.get_local_id(0);
190 u64 group_tile_id = item.get_group_linear_id();
191 u64 gid = group_tile_id * work_group_size + lid;
192
193 u64 iread = gid * slice_read_size;
194
195 if (gid >= remaining_val) {
196 return;
197 }
198
199 result[gid] = compute_buf[iread];
200 });
201 });
202
204 {
205 sycl::host_accessor acc{recov, sycl::read_only};
206 for (u64 i = 0; i < remaining_val; i++) {
207 ret = sham::min(acc[i], ret);
208 }
209 }
210
211 return ret;
212 }
213
214 template<class T, u32 work_group_size>
215 inline T GroupReduction<T, work_group_size>::max(
216 sycl::queue &q, sycl::buffer<T> &buf1, u32 start_id, u32 end_id) {
217 u32 len = end_id - start_id;
218
219 sycl::buffer<T> buf_int(len);
220
221 shamalgs::memory::write_with_offset_into(q, buf_int, buf1, start_id, len);
222
223 u32 cur_slice_sz = 1;
224 u32 remaining_val = len;
225 while (len / cur_slice_sz > work_group_size * 8) {
226
227 sycl::nd_range<1> exec_range = shambase::make_range(remaining_val, work_group_size);
228
229 q.submit([&](sycl::handler &cgh) {
230 sycl::accessor global_mem{buf_int, cgh, sycl::read_write};
231
232 u32 slice_read_size = cur_slice_sz;
233 u32 slice_write_size = cur_slice_sz * work_group_size;
234 u32 max_id = len;
235
237 exec_range, [=](sycl::nd_item<1> item) {
238 u64 lid = item.get_local_id(0);
239 u64 group_tile_id = item.get_group_linear_id();
240 u64 gid = group_tile_id * work_group_size + lid;
241
242 u64 iread = gid * slice_read_size;
243 u64 iwrite = group_tile_id * slice_write_size;
244
245 T val_read = (iread < max_id) ? global_mem[iread]
246 : shambase::VectorProperties<T>::get_min();
247
248 T local_red = sham::max_over_group(item.get_group(), val_read);
249
250 // can be removed if i change the index in the look back ?
251 if (lid == 0) {
252 global_mem[iwrite] = local_red;
253 }
254 });
255 });
256
257 cur_slice_sz *= work_group_size;
258 remaining_val = exec_range.get_group_range().size();
259 }
260
261 sycl::buffer<T> recov{remaining_val};
262
263 sycl::nd_range<1> exec_range = shambase::make_range(remaining_val, work_group_size);
264 q.submit([&, remaining_val](sycl::handler &cgh) {
265 sycl::accessor compute_buf{buf_int, cgh, sycl::read_only};
266 sycl::accessor result{recov, cgh, sycl::write_only, sycl::no_init};
267
268 u32 slice_read_size = cur_slice_sz;
269
270 cgh.parallel_for(exec_range, [=](sycl::nd_item<1> item) {
271 u64 lid = item.get_local_id(0);
272 u64 group_tile_id = item.get_group_linear_id();
273 u64 gid = group_tile_id * work_group_size + lid;
274
275 u64 iread = gid * slice_read_size;
276
277 if (gid >= remaining_val) {
278 return;
279 }
280
281 result[gid] = compute_buf[iread];
282 });
283 });
284
286 {
287 sycl::host_accessor acc{recov, sycl::read_only};
288 for (u64 i = 0; i < remaining_val; i++) {
289 ret = sham::max(acc[i], ret);
290 }
291 }
292
293 return ret;
294 }
295#endif
296
297} // namespace shamalgs::reduction::details
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
Namespace for basic C++ utilities.
sycl::nd_range< 1 > make_range(u32 length, const u32 group_size=32)
Generate a sycl nd range out of a group size and length.
Main include file for memory algorithms.