Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
sycl_utils.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
 12#include "shambase/SourceLocation.hpp"
 13#include "shambase/exception.hpp"
 20#include "shambase/integer.hpp"
 21#include "shambase/string.hpp"
 22#include "shambackends/sycl.hpp"
 23#include "shambackends/vec.hpp"
 24#include "shamcomm/logs.hpp"
 25#include <stdexcept>
26
27#ifdef SHAMROCK_USE_NVTX
28 #include <nvtx3/nvtx3.hpp>
29#endif
30
31namespace shambase {
32
42 template<class T>
44 sycl::buffer<T> &buf, u64 max_range, const SourceLocation loc = SourceLocation()) {
45 if (buf.size() < max_range) {
46 throw make_except_with_loc<std::invalid_argument>("buffer is too small", loc);
47 }
48 }
49
56 inline std::string getDevice_type(const sycl::device &Device) {
57 auto DeviceType = Device.get_info<sycl::info::device::device_type>();
58 switch (DeviceType) {
59 case sycl::info::device_type::cpu : return "CPU";
60 case sycl::info::device_type::gpu : return "GPU";
61 case sycl::info::device_type::host : return "HOST";
62 case sycl::info::device_type::accelerator: return "ACCELERATOR";
63 default : return "UNKNOWN";
64 }
65 }
66
74 inline sycl::nd_range<1> make_range(u32 length, const u32 group_size = 32) {
75 u32 group_cnt = shambase::group_count(length, group_size);
76 u32 len = group_cnt * group_size;
77 return sycl::nd_range<1>{len, group_size};
78 }
79
    /// Strategy used by the parallel_for wrappers below to map indices onto SYCL launch APIs:
    /// PARALLEL_FOR       -> plain sycl::range launch, exact size
    /// PARALLEL_FOR_ROUND -> sycl::range rounded up to a multiple of the group size, tail guarded
    /// ND_RANGE           -> sycl::nd_range launch (explicit work groups), tail guarded
    enum ParallelForWrapMode { PARALLEL_FOR, PARALLEL_FOR_ROUND, ND_RANGE };

// exactly one of the SHAMROCK_LOOP_DEFAULT_* macros is expected to be defined
// by the build system to select the default launch strategy
#ifdef SHAMROCK_LOOP_DEFAULT_PARALLEL_FOR
    constexpr ParallelForWrapMode default_loop_mode = PARALLEL_FOR;
#endif

#ifdef SHAMROCK_LOOP_DEFAULT_PARALLEL_FOR_ROUND
    constexpr ParallelForWrapMode default_loop_mode = PARALLEL_FOR_ROUND;
#endif

#ifdef SHAMROCK_LOOP_DEFAULT_ND_RANGE
    constexpr ParallelForWrapMode default_loop_mode = ND_RANGE;
#endif

    /// Default 1D work-group size, injected by the build system
    constexpr u32 default_gsize = SHAMROCK_LOOP_GSIZE;
    /// Default per-axis work-group size for the 2D wrapper (16x16 = 256 items)
    constexpr u32 default_gsize_2d = 16;
    /// Default per-axis work-group size for the 3D wrapper (4x4x4 = 64 items)
    constexpr u32 default_gsize_3d = 4;
97
98 template<
99 u32 group_size = default_gsize,
100 ParallelForWrapMode mode = default_loop_mode,
101 class LambdaKernel>
102 inline void parallel_for(sycl::handler &cgh, u32 length, const char *name, LambdaKernel &&ker) {
103
104#ifdef SHAMROCK_USE_NVTX
105 nvtxRangePush(name);
106#endif
107
108 shamlog_debug_sycl_ln("SYCL", shambase::format("parallel_for {} N={}", name, length));
109
110 if constexpr (mode == PARALLEL_FOR) {
111
112 cgh.parallel_for(sycl::range<1>{length}, [=](sycl::item<1> id) {
113 ker(id.get_linear_id());
114 });
115
116 } else if constexpr (mode == PARALLEL_FOR_ROUND) {
117
118 u32 len = shambase::group_count(length, group_size) * group_size;
119
120 cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
121 u64 gid = id.get_linear_id();
122 if (gid >= length)
123 return;
124
125 ker(gid);
126 });
127
128 } else if constexpr (mode == ND_RANGE) {
129
130 cgh.parallel_for(make_range(length, group_size), [=](sycl::nd_item<1> id) {
131 u64 gid = id.get_global_linear_id();
132 if (gid >= length)
133 return;
134
135 ker(gid);
136 });
137
138 } else {
140 }
141
142#ifdef SHAMROCK_USE_NVTX
143 nvtxRangePop();
144#endif
145 }
146
147 template<
148 u32 group_size = default_gsize_2d,
149 ParallelForWrapMode mode = default_loop_mode,
150 class LambdaKernel>
151 inline void parallel_for_2d(
152 sycl::handler &cgh, u32 length_x, u32 length_y, const char *name, LambdaKernel &&ker) {
153
154#ifdef SHAMROCK_USE_NVTX
155 nvtxRangePush(name);
156#endif
157
158 shamlog_debug_sycl_ln(
159 "SYCL", shambase::format("parallel_for {} N={} {}", name, length_x, length_y));
160
161 if constexpr (mode == PARALLEL_FOR) {
162
163 cgh.parallel_for(sycl::range<2>{length_x, length_y}, [=](sycl::item<2> id) {
164 ker(id.get_id(0), id.get_id(1));
165 });
166
167 } else if constexpr (mode == PARALLEL_FOR_ROUND) {
168
169 u32 len_x = shambase::group_count(length_x, group_size) * group_size;
170 u32 len_y = shambase::group_count(length_y, group_size) * group_size;
171
172 cgh.parallel_for(sycl::range<2>{len_x, len_y}, [=](sycl::item<2> id) {
173 if (id.get_id(0) >= length_x || id.get_id(1) >= length_y)
174 return;
175
176 ker(id.get_id(0), id.get_id(1));
177 });
178
179 } else if constexpr (mode == ND_RANGE) {
180
181 sycl::nd_range<1> rx = make_range(length_x, group_size);
182 sycl::nd_range<1> ry = make_range(length_y, group_size);
183
184 sycl::range<2> tmp_s{rx.get_global_range().size(), ry.get_global_range().size()};
185 sycl::range<2> tmp_g{rx.get_group_range().size(), ry.get_group_range().size()};
186
187 cgh.parallel_for(sycl::nd_range<2>{tmp_s, tmp_g}, [=](sycl::nd_item<2> id) {
188 if (id.get_global_id(0) >= length_x || id.get_global_id(1) >= length_y)
189 return;
190
191 ker(id.get_global_id(0), id.get_global_id(1));
192 });
193
194 } else {
196 }
197
198#ifdef SHAMROCK_USE_NVTX
199 nvtxRangePop();
200#endif
201 }
202
203 template<
204 u32 group_size = default_gsize_3d,
205 ParallelForWrapMode mode = default_loop_mode,
206 class LambdaKernel>
207 inline void parallel_for_3d(
208 sycl::handler &cgh,
212 const char *name,
213 LambdaKernel &&ker) {
214
215#ifdef SHAMROCK_USE_NVTX
216 nvtxRangePush(name);
217#endif
218
219 shamlog_debug_sycl_ln(
220 "SYCL",
221 shambase::format("parallel_for {} N={} {} {}", name, length_x, length_y, length_z));
222
223 if constexpr (mode == PARALLEL_FOR) {
224
225 cgh.parallel_for(sycl::range<3>{length_x, length_y, length_z}, [=](sycl::item<3> id) {
226 ker(id.get_id(0), id.get_id(1), id.get_id(2));
227 });
228
229 } else if constexpr (mode == PARALLEL_FOR_ROUND) {
230
231 u32 len_x = shambase::group_count(length_x, group_size) * group_size;
232 u32 len_y = shambase::group_count(length_y, group_size) * group_size;
233 u32 len_z = shambase::group_count(length_z, group_size) * group_size;
234
235 cgh.parallel_for(sycl::range<3>{len_x, len_y, len_z}, [=](sycl::item<3> id) {
236 if (id.get_id(0) >= length_x || id.get_id(1) >= length_y
237 || id.get_id(2) >= length_z)
238 return;
239
240 ker(id.get_id(0), id.get_id(1), id.get_id(2));
241 });
242
243 } else if constexpr (mode == ND_RANGE) {
244
245 sycl::nd_range<1> rx = make_range(length_x, group_size);
246 sycl::nd_range<1> ry = make_range(length_y, group_size);
247 sycl::nd_range<1> rz = make_range(length_z, group_size);
248
249 sycl::range<3> tmp_s{
250 rx.get_global_range().size(),
251 ry.get_global_range().size(),
252 rz.get_global_range().size()};
253 sycl::range<3> tmp_g{
254 rx.get_group_range().size(),
255 ry.get_group_range().size(),
256 rz.get_group_range().size()};
257
258 cgh.parallel_for(sycl::nd_range<3>{tmp_s, tmp_g}, [=](sycl::nd_item<3> id) {
259 if (id.get_global_id(0) >= length_x || id.get_global_id(1) >= length_y
260 || id.get_global_id(2) >= length_z)
261 return;
262
263 ker(id.get_global_id(0), id.get_global_id(1), id.get_global_id(2));
264 });
265
266 } else {
268 }
269
270#ifdef SHAMROCK_USE_NVTX
271 nvtxRangePop();
272#endif
273 }
274
275 template<ParallelForWrapMode mode = default_loop_mode, class LambdaKernel>
276 inline void parallel_for_gsize(
277 sycl::handler &cgh, u32 length, u32 group_size, const char *name, LambdaKernel &&ker) {
278
279#ifdef SHAMROCK_USE_NVTX
280 nvtxRangePush(name);
281#endif
282
283 if constexpr (mode == PARALLEL_FOR) {
284
285 cgh.parallel_for(sycl::range<1>{length}, [=](sycl::item<1> id) {
286 ker(id.get_linear_id());
287 });
288
289 } else if constexpr (mode == PARALLEL_FOR_ROUND) {
290
291 u32 len = shambase::group_count(length, group_size) * group_size;
292
293 cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
294 u64 gid = id.get_linear_id();
295 if (gid >= length)
296 return;
297
298 ker(gid);
299 });
300
301 } else if constexpr (mode == ND_RANGE) {
302
303 cgh.parallel_for(make_range(length, group_size), [=](sycl::nd_item<1> id) {
304 u64 gid = id.get_global_linear_id();
305 if (gid >= length)
306 return;
307
308 ker(gid);
309 });
310
311 } else {
313 }
314
315#ifdef SHAMROCK_USE_NVTX
316 nvtxRangePop();
317#endif
318 }
319
    /// Block until every operation queued on q has completed, rethrowing any
    /// asynchronous error it produced (via wait_and_throw).
    ///
    /// @param q the queue to synchronize on
    /// @param loc call-site location, forwarded to the debug log for traceability
    inline void check_queue_state(sycl::queue &q, SourceLocation loc = SourceLocation()) {
        shamlog_debug_sycl_ln("SYCL", "checking queue state", loc.format_one_line());
        q.wait_and_throw();
        shamlog_debug_sycl_ln("SYCL", "checking queue state : OK !");
    }
325
326} // namespace shambase
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
This header file contains utility functions related to exception handling in the code.
namespace for basic c++ utilities
constexpr u32 group_count(u32 len, u32 group_size)
Calculates the number of groups based on the length and group size.
Definition integer.hpp:125
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
sycl::nd_range< 1 > make_range(u32 length, const u32 group_size=32)
Generate a sycl nd range out of a group size and length.
void check_buffer_size(sycl::buffer< T > &buf, u64 max_range, const SourceLocation loc=SourceLocation())
check that a sycl buffer can hold at least max_range elements; throws std::invalid_argument if the buffer is smaller than that.
std::string getDevice_type(const sycl::device &Device)
Get the Device Type Name.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
This file contains the definition for the stacktrace related functionality.
provide information about the source location