27#ifdef SHAMROCK_USE_NVTX
28 #include <nvtx3/nvtx3.hpp>
45 if (buf.size() < max_range) {
57 auto DeviceType = Device.get_info<sycl::info::device::device_type>();
59 case sycl::info::device_type::cpu :
return "CPU";
60 case sycl::info::device_type::gpu :
return "GPU";
61 case sycl::info::device_type::host :
return "HOST";
62 case sycl::info::device_type::accelerator:
return "ACCELERATOR";
77 return sycl::nd_range<1>{
len, group_size};
80 enum ParallelForWrapMode { PARALLEL_FOR, PARALLEL_FOR_ROUND, ND_RANGE };
82#ifdef SHAMROCK_LOOP_DEFAULT_PARALLEL_FOR
86#ifdef SHAMROCK_LOOP_DEFAULT_PARALLEL_FOR_ROUND
90#ifdef SHAMROCK_LOOP_DEFAULT_ND_RANGE
95 constexpr u32 default_gsize_2d = 16;
96 constexpr u32 default_gsize_3d = 4;
99 u32 group_size = default_gsize,
104#ifdef SHAMROCK_USE_NVTX
108 shamlog_debug_sycl_ln(
"SYCL", shambase::format(
"parallel_for {} N={}", name, length));
110 if constexpr (mode == PARALLEL_FOR) {
112 cgh.parallel_for(sycl::range<1>{length}, [=](sycl::item<1>
id) {
116 }
else if constexpr (mode == PARALLEL_FOR_ROUND) {
120 cgh.parallel_for(sycl::range<1>{
len}, [=](sycl::item<1>
id) {
121 u64 gid =
id.get_linear_id();
128 }
else if constexpr (mode == ND_RANGE) {
130 cgh.parallel_for(
make_range(length, group_size), [=](sycl::nd_item<1>
id) {
131 u64 gid =
id.get_global_linear_id();
142#ifdef SHAMROCK_USE_NVTX
148 u32 group_size = default_gsize_2d,
151 inline void parallel_for_2d(
154#ifdef SHAMROCK_USE_NVTX
158 shamlog_debug_sycl_ln(
159 "SYCL", shambase::format(
"parallel_for {} N={} {}", name,
length_x,
length_y));
161 if constexpr (mode == PARALLEL_FOR) {
167 }
else if constexpr (mode == PARALLEL_FOR_ROUND) {
179 }
else if constexpr (mode == ND_RANGE) {
184 sycl::range<2>
tmp_s{
rx.get_global_range().size(),
ry.get_global_range().size()};
185 sycl::range<2>
tmp_g{
rx.get_group_range().size(),
ry.get_group_range().size()};
187 cgh.parallel_for(sycl::nd_range<2>{
tmp_s,
tmp_g}, [=](sycl::nd_item<2>
id) {
198#ifdef SHAMROCK_USE_NVTX
204 u32 group_size = default_gsize_3d,
207 inline void parallel_for_3d(
215#ifdef SHAMROCK_USE_NVTX
219 shamlog_debug_sycl_ln(
223 if constexpr (mode == PARALLEL_FOR) {
229 }
else if constexpr (mode == PARALLEL_FOR_ROUND) {
243 }
else if constexpr (mode == ND_RANGE) {
249 sycl::range<3>
tmp_s{
250 rx.get_global_range().size(),
251 ry.get_global_range().size(),
252 rz.get_global_range().size()};
253 sycl::range<3>
tmp_g{
254 rx.get_group_range().size(),
255 ry.get_group_range().size(),
256 rz.get_group_range().size()};
258 cgh.parallel_for(sycl::nd_range<3>{
tmp_s,
tmp_g}, [=](sycl::nd_item<3>
id) {
270#ifdef SHAMROCK_USE_NVTX
275 template<ParallelForWrapMode mode = default_loop_mode,
class LambdaKernel>
276 inline void parallel_for_gsize(
279#ifdef SHAMROCK_USE_NVTX
283 if constexpr (mode == PARALLEL_FOR) {
285 cgh.parallel_for(sycl::range<1>{length}, [=](sycl::item<1>
id) {
289 }
else if constexpr (mode == PARALLEL_FOR_ROUND) {
293 cgh.parallel_for(sycl::range<1>{
len}, [=](sycl::item<1>
id) {
294 u64 gid =
id.get_linear_id();
301 }
else if constexpr (mode == ND_RANGE) {
303 cgh.parallel_for(
make_range(length, group_size), [=](sycl::nd_item<1>
id) {
304 u64 gid =
id.get_global_linear_id();
315#ifdef SHAMROCK_USE_NVTX
321 shamlog_debug_sycl_ln(
"SYCL",
"checking queue state", loc.format_one_line());
323 shamlog_debug_sycl_ln(
"SYCL",
"checking queue state : OK !");
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
This header file contains utility functions related to exception handling in the code.
Namespace for basic C++ utilities.
constexpr u32 group_count(u32 len, u32 group_size)
Calculates the number of groups based on the length and group size.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
sycl::nd_range< 1 > make_range(u32 length, const u32 group_size=32)
Generate a one-dimensional sycl::nd_range from a length and a group size.
void check_buffer_size(sycl::buffer< T > &buf, u64 max_range, const SourceLocation loc=SourceLocation())
Check that the size of a SYCL buffer is below or equal to the value of max_range; throw if it is not.
std::string getDevice_type(const sycl::device &Device)
Get the Device Type Name.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
This file contains the definition for the stacktrace related functionality.
Provides information about the source location.