27#ifdef SHAMROCK_USE_NVTX
28 #include <nvtx3/nvtx3.hpp>
45 if (buf.size() < max_range) {
57 auto DeviceType = Device.get_info<sycl::info::device::device_type>();
59 case sycl::info::device_type::cpu :
return "CPU";
60 case sycl::info::device_type::gpu :
return "GPU";
61 case sycl::info::device_type::host :
return "HOST";
62 case sycl::info::device_type::accelerator:
return "ACCELERATOR";
63 default :
return "UNKNOWN";
76 u32 len = group_cnt * group_size;
77 return sycl::nd_range<1>{len, group_size};
// Selects how the parallel_for* wrappers in this file submit a kernel; each
// wrapper branches on its `mode` parameter with `if constexpr`:
//  - PARALLEL_FOR       : launch over a sycl::range built from the requested length(s).
//  - PARALLEL_FOR_ROUND : launch over a sycl::range built from separate `len*` values
//                         (their computation is not visible in this listing); the 2D/3D
//                         variants test each index against the requested length.
//  - ND_RANGE           : launch over a sycl::nd_range; the 1D wrappers obtain it
//                         directly from make_range(length, group_size).
80 enum ParallelForWrapMode { PARALLEL_FOR, PARALLEL_FOR_ROUND, ND_RANGE };
// Build-time choice of default_loop_mode: each SHAMROCK_LOOP_DEFAULT_* macro, when
// defined, defines the constant with the matching ParallelForWrapMode enumerator.
// default_loop_mode is the default value of the `mode` parameter of the wrappers below.
// NOTE(review): the matching #endif lines are not visible in this listing, and nothing
// shown here prevents several of these macros from being defined at once — confirm
// against the build configuration.
82#ifdef SHAMROCK_LOOP_DEFAULT_PARALLEL_FOR
83 constexpr ParallelForWrapMode default_loop_mode = PARALLEL_FOR;
86#ifdef SHAMROCK_LOOP_DEFAULT_PARALLEL_FOR_ROUND
87 constexpr ParallelForWrapMode default_loop_mode = PARALLEL_FOR_ROUND;
90#ifdef SHAMROCK_LOOP_DEFAULT_ND_RANGE
91 constexpr ParallelForWrapMode default_loop_mode = ND_RANGE;
// Default group sizes, used as the default value of the `group_size` parameter
// declared just before each wrapper.
// 1D default (parallel_for); the value comes from the SHAMROCK_LOOP_GSIZE macro.
94 constexpr u32 default_gsize = SHAMROCK_LOOP_GSIZE;
// Default for parallel_for_2d; the same group_size is passed to make_range for x and y.
95 constexpr u32 default_gsize_2d = 16;
// Default for parallel_for_3d; the same group_size is passed to make_range for x, y and z.
96 constexpr u32 default_gsize_3d = 4;
99 u32 group_size = default_gsize,
100 ParallelForWrapMode mode = default_loop_mode,
102 inline void parallel_for(sycl::handler &cgh,
u32 length,
const char *name, LambdaKernel &&ker) {
104#ifdef SHAMROCK_USE_NVTX
108 shamlog_debug_sycl_ln(
"SYCL", shambase::format(
"parallel_for {} N={}", name, length));
110 if constexpr (mode == PARALLEL_FOR) {
112 cgh.parallel_for(sycl::range<1>{length}, [=](sycl::item<1> id) {
113 ker(
id.get_linear_id());
116 }
else if constexpr (mode == PARALLEL_FOR_ROUND) {
120 cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
121 u64 gid =
id.get_linear_id();
128 }
else if constexpr (mode == ND_RANGE) {
130 cgh.parallel_for(
make_range(length, group_size), [=](sycl::nd_item<1>
id) {
131 u64 gid =
id.get_global_linear_id();
142#ifdef SHAMROCK_USE_NVTX
148 u32 group_size = default_gsize_2d,
149 ParallelForWrapMode mode = default_loop_mode,
151 inline void parallel_for_2d(
152 sycl::handler &cgh,
u32 length_x,
u32 length_y,
const char *name, LambdaKernel &&ker) {
154#ifdef SHAMROCK_USE_NVTX
158 shamlog_debug_sycl_ln(
159 "SYCL", shambase::format(
"parallel_for {} N={} {}", name, length_x, length_y));
161 if constexpr (mode == PARALLEL_FOR) {
163 cgh.parallel_for(sycl::range<2>{length_x, length_y}, [=](sycl::item<2> id) {
164 ker(
id.get_id(0),
id.get_id(1));
167 }
else if constexpr (mode == PARALLEL_FOR_ROUND) {
172 cgh.parallel_for(sycl::range<2>{len_x, len_y}, [=](sycl::item<2> id) {
173 if (
id.get_id(0) >= length_x ||
id.get_id(1) >= length_y)
176 ker(
id.get_id(0),
id.get_id(1));
179 }
else if constexpr (mode == ND_RANGE) {
181 sycl::nd_range<1> rx =
make_range(length_x, group_size);
182 sycl::nd_range<1> ry =
make_range(length_y, group_size);
184 sycl::range<2> tmp_s{rx.get_global_range().size(), ry.get_global_range().size()};
185 sycl::range<2> tmp_g{rx.get_group_range().size(), ry.get_group_range().size()};
187 cgh.parallel_for(sycl::nd_range<2>{tmp_s, tmp_g}, [=](sycl::nd_item<2> id) {
188 if (
id.get_global_id(0) >= length_x ||
id.get_global_id(1) >= length_y)
191 ker(
id.get_global_id(0),
id.get_global_id(1));
198#ifdef SHAMROCK_USE_NVTX
204 u32 group_size = default_gsize_3d,
205 ParallelForWrapMode mode = default_loop_mode,
207 inline void parallel_for_3d(
213 LambdaKernel &&ker) {
215#ifdef SHAMROCK_USE_NVTX
219 shamlog_debug_sycl_ln(
221 shambase::format(
"parallel_for {} N={} {} {}", name, length_x, length_y, length_z));
223 if constexpr (mode == PARALLEL_FOR) {
225 cgh.parallel_for(sycl::range<3>{length_x, length_y, length_z}, [=](sycl::item<3> id) {
226 ker(
id.get_id(0),
id.get_id(1),
id.get_id(2));
229 }
else if constexpr (mode == PARALLEL_FOR_ROUND) {
235 cgh.parallel_for(sycl::range<3>{len_x, len_y, len_z}, [=](sycl::item<3> id) {
236 if (
id.get_id(0) >= length_x ||
id.get_id(1) >= length_y
237 ||
id.get_id(2) >= length_z)
240 ker(
id.get_id(0),
id.get_id(1),
id.get_id(2));
243 }
else if constexpr (mode == ND_RANGE) {
245 sycl::nd_range<1> rx =
make_range(length_x, group_size);
246 sycl::nd_range<1> ry =
make_range(length_y, group_size);
247 sycl::nd_range<1> rz =
make_range(length_z, group_size);
249 sycl::range<3> tmp_s{
250 rx.get_global_range().size(),
251 ry.get_global_range().size(),
252 rz.get_global_range().size()};
253 sycl::range<3> tmp_g{
254 rx.get_group_range().size(),
255 ry.get_group_range().size(),
256 rz.get_group_range().size()};
258 cgh.parallel_for(sycl::nd_range<3>{tmp_s, tmp_g}, [=](sycl::nd_item<3> id) {
259 if (
id.get_global_id(0) >= length_x ||
id.get_global_id(1) >= length_y
260 ||
id.get_global_id(2) >= length_z)
263 ker(
id.get_global_id(0),
id.get_global_id(1),
id.get_global_id(2));
270#ifdef SHAMROCK_USE_NVTX
275 template<ParallelForWrapMode mode = default_loop_mode,
class LambdaKernel>
276 inline void parallel_for_gsize(
277 sycl::handler &cgh,
u32 length,
u32 group_size,
const char *name, LambdaKernel &&ker) {
279#ifdef SHAMROCK_USE_NVTX
283 if constexpr (mode == PARALLEL_FOR) {
285 cgh.parallel_for(sycl::range<1>{length}, [=](sycl::item<1> id) {
286 ker(
id.get_linear_id());
289 }
else if constexpr (mode == PARALLEL_FOR_ROUND) {
293 cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
294 u64 gid =
id.get_linear_id();
301 }
else if constexpr (mode == ND_RANGE) {
303 cgh.parallel_for(
make_range(length, group_size), [=](sycl::nd_item<1>
id) {
304 u64 gid =
id.get_global_linear_id();
315#ifdef SHAMROCK_USE_NVTX
320 inline void check_queue_state(sycl::queue &q, SourceLocation loc = SourceLocation()) {
321 shamlog_debug_sycl_ln(
"SYCL",
"checking queue state", loc.format_one_line());
323 shamlog_debug_sycl_ln(
"SYCL",
"checking queue state : OK !");
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
This header file contains utility functions related to exception handling in the code.
Namespace for basic C++ utilities.
constexpr u32 group_count(u32 len, u32 group_size)
Calculates the number of groups based on the length and group size.
sycl::nd_range< 1 > make_range(u32 length, const u32 group_size=32)
Generate a SYCL nd_range from a length and a group size.
void check_buffer_size(sycl::buffer< T > &buf, u64 max_range, const SourceLocation loc=SourceLocation())
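Check that the size of a SYCL buffer is below or equal to the value of max_range; throw if it is not the case. (Note: the guard visible in the listing is `buf.size() < max_range`; confirm the intended direction of the comparison.)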
std::string getDevice_type(const sycl::device &Device)
Get the Device Type Name.
ExcptTypes make_except_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Create an exception with a message and a location.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
This file contains the definition for the stacktrace related functionality.
Provides information about the source location.