Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
kernel_call.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
19#include "shambase/optional.hpp"
21#include <functional>
22#include <optional>
23
24namespace sham {
25
26 namespace details {
27
/// @brief Get a pointer to the data of an optional device buffer, for reading.
/// If the optional holds no buffer, nullptr is returned and depends_list is not touched;
/// otherwise the call is forwarded to get_read_access(depends_list) on the referenced buffer.
37 template<class T>
// Absent buffer: the caller receives a null pointer in place of an accessor.
40 if (!buffer.has_value()) {
41 return nullptr;
42 } else {
43 return buffer.value().get().get_read_access(depends_list);
44 }
45 }
46
/// @brief Get a pointer to the data of an optional device buffer, for writing.
/// If the optional holds no buffer, nullptr is returned and depends_list is not touched;
/// otherwise the call is forwarded to get_write_access(depends_list) on the referenced buffer.
56 template<class T>
// Absent buffer: the caller receives a null pointer in place of an accessor.
59 if (!buffer.has_value()) {
60 return nullptr;
61 } else {
62 return buffer.value().get().get_write_access(depends_list);
63 }
64 }
65
71 template<class T>
72 void complete_state_optional(sycl::event e, shambase::opt_ref<T> buffer) {
73 if (buffer.has_value()) {
74 buffer.value().get().complete_event_state(e);
75 }
76 }
77
78 template<class Obj>
79 inline auto get_read_access(Obj &o, sham::EventList &depends_list) {
80 return o.get_read_access(depends_list);
81 }
82
83 template<class Obj>
84 inline auto get_write_access(Obj &o, sham::EventList &depends_list) {
85 return o.get_write_access(depends_list);
86 }
87 template<class Obj>
88 inline auto complete_event_state(Obj &o, sycl::event e) {
89 return o.complete_event_state(e);
90 }
91
92 template<class Obj>
93 inline auto get_read_access(std::reference_wrapper<Obj> &o, sham::EventList &depends_list) {
94 return o.get().get_read_access(depends_list);
95 }
96
97 template<class Obj>
98 inline auto get_write_access(
99 std::reference_wrapper<Obj> &o, sham::EventList &depends_list) {
100 return o.get().get_write_access(depends_list);
101 }
102 template<class Obj>
103 inline auto complete_event_state(std::reference_wrapper<Obj> &o, sycl::event e) {
104 return o.get().complete_event_state(e);
105 }
106
107 } // namespace details
108
/// @brief Converts a reference to an object into an optional reference wrapper
/// (shambase::opt_ref<T>) that refers to that same object.
115 template<class T>
117 return t;
118 }
119
/// @brief Returns an empty optional reference to a sham::DeviceBuffer<T>.
/// Presumably meant to fill an unused slot of a MultiRefOpt, whose accessors then
/// yield nullptr for it — confirm against call sites.
127 template<class T>
131
/// @brief A variant of MultiRef whose entries are optional references (shambase::opt_ref<Targ>).
/// Empty entries are skipped when acquiring access and yield nullptr in the accessor tuples.
/// NOTE(review): the lambda parameter names `__a` / `__in` contain a double underscore and
/// are therefore reserved identifiers in C++; consider renaming.
140 template<class... Targ>
141 struct MultiRefOpt {
143 using storage_t = std::tuple<shambase::opt_ref<Targ>...>;
144
147
150
/// Returns a std::tuple with one read pointer per entry, in storage order
/// (nullptr for empty entries); depends_list is passed to each present buffer.
160 auto get_read_access(sham::EventList &depends_list) {
162 return std::apply(
163 [&](auto &...__a) {
164 return std::tuple(details::read_access_optional(__a, depends_list)...);
165 },
166 storage);
167 }
/// Returns a std::tuple with one write pointer per entry, in storage order
/// (nullptr for empty entries); depends_list is passed to each present buffer.
177 auto get_write_access(sham::EventList &depends_list) {
179 return std::apply(
180 [&](auto &...__a) {
181 return std::tuple(details::write_access_optional(__a, depends_list)...);
182 },
183 storage);
184 }
185
/// Forwards event e to complete_event_state of every present entry (empty entries are skipped).
193 void complete_event_state(sycl::event e) {
195 std::apply(
196 [&](auto &...__in) {
// comma-fold: visit every entry in storage order
197 ((details::complete_state_optional(e, __in)), ...);
198 },
199 storage);
200 }
201 };
202
203 namespace details {
205 template<class T>
206 struct mapper {
208 using type = T;
209 };
210
212 template<class T>
213 struct mapper<shambase::opt_ref<T>> {
215 using type = T;
216 };
217 } // namespace details
218
/// Deduction guide for MultiRef; presumably maps each argument type through details::mapper
/// (shambase::opt_ref<T> -> T, other types unchanged) — confirm against the guide itself.
220 template<class... Targ>
222
/// @brief A class that references multiple buffers or similar objects.
/// storage_t is a tuple of references, so copying a MultiRef (as the kernel_call functions do
/// when taking it by value) copies references, not the referenced objects.
/// Each entry may be a plain object or a std::reference_wrapper: the details:: helpers dispatch on both.
233 template<class... Targ>
234 struct MultiRef {
236 using storage_t = std::tuple<Targ &...>;
237
240
242 MultiRef(Targ &...arg) : storage(arg...) {}
243
/// Returns a std::tuple of the entries' read accessors, in storage order;
/// depends_list is passed to every entry's get_read_access.
246 auto get_read_access(sham::EventList &depends_list) {
248 return std::apply(
249 [&](auto &...__a) {
250 return std::tuple(details::get_read_access(__a, depends_list)...);
251 },
252 storage);
253 }
254
/// Returns a std::tuple of the entries' write accessors, in storage order;
/// depends_list is passed to every entry's get_write_access.
257 auto get_write_access(sham::EventList &depends_list) {
259 return std::apply(
260 [&](auto &...__a) {
261 return std::tuple(details::get_write_access(__a, depends_list)...);
262 },
263 storage);
264 }
265
/// Forwards event e to complete_event_state of every entry, in storage order.
268 void complete_event_state(sycl::event e) {
270 std::apply(
271 [&](auto &...__in) {
272 ((details::complete_event_state(__in, e)), ...);
273 },
274 storage);
275 }
276 };
277
278 namespace details {
279
/// @brief Shared implementation of the kernel_call family, taking a kernel generator.
/// Steps: reject n == 0 (std::runtime_error), acquire read accessors on `in` and write
/// accessors on `in_out` (both feed the same depends_list), build the command-group function
/// with kernel_gen(n, <read accessors>..., <write accessors>...), submit it on q with
/// depends_list, then pass the resulting event to in/in_out.complete_event_state.
/// NOTE(review): if kernel_gen or q.submit throws, complete_event_state is never called for
/// either set — confirm whether the buffers can be left in an acquired state.
281 template<class index_t, class RefIn, class RefOut, class Functor>
284 RefIn in,
285 RefOut in_out,
286 index_t n,
287 Functor &&kernel_gen,
288 SourceLocation &&callsite = SourceLocation{}) {
289
291
// Checked before any accessor is acquired, so nothing needs releasing on this path.
292 if (n == 0) {
293 shambase::throw_with_loc<std::runtime_error>("kernel call with : n == 0");
294 }
295
296 sham::EventList depends_list;
297
298 auto acc_in = in.get_read_access(depends_list);
299 auto acc_in_out = in_out.get_write_access(depends_list);
300
301 sycl::event e;
302
303 // unpack the tuples of accessors
304 std::apply(
305 [&](auto &...__acc_in) {
306 std::apply(
307 [&](auto &...__acc_in_out) {
308 // submit the kernel generated by the functor
309 e = q.submit(depends_list, kernel_gen(n, __acc_in..., __acc_in_out...));
310 },
311 acc_in_out);
312 },
313 acc_in);
314
// Hand the submission event back to every referenced object, inputs and outputs alike.
315 in.complete_event_state(e);
316 in_out.complete_event_state(e);
317 }
318
/// @brief Internal implementation of kernel_call / kernel_call_u64.
/// Wraps `func` into a kernel generator whose command group runs a 1D parallel_for over n
/// items, calling func(index_t(item id), <read accessors>..., <write accessors>...) for each,
/// and delegates to typed_index_kernel_call_lambda.
/// NOTE(review): the generator lambda declares its size parameter as `u32 n`, but
/// typed_index_kernel_call_lambda passes an index_t. For index_t = u64 this implicitly
/// converts to u32, so sizes >= 2^32 would be silently truncated. The parameter probably
/// should be `index_t n`; verify and fix.
/// NOTE(review): in `auto... __acc_in, auto... __acc_in_out` the first pack is not the last
/// parameter, so it is never deduced and stays empty; every accessor lands in __acc_in_out.
/// The call still works because both packs are expanded back-to-back in the same order.
/// The names are misleading, and `__`-prefixed names are reserved identifiers.
320 template<class index_t, class RefIn, class RefOut, class Functor>
323 RefIn in,
324 RefOut in_out,
325 index_t n,
326 Functor &&func,
327 SourceLocation &&callsite = SourceLocation{}) {
328
329 __shamrock_log_callsite(callsite);
330
331 typed_index_kernel_call_lambda(
332 q,
333 in,
334 in_out,
335 n,
// func is moved/copied into the generator, then captured by value into the device kernel.
336 [func
337 = std::forward<Functor>(func)](u32 n, auto... __acc_in, auto... __acc_in_out) {
338 return [=](sycl::handler &cgh) {
339 cgh.parallel_for(sycl::range<1>{n}, [=](sycl::item<1> item) {
341 func, index_t(item.get_linear_id()), __acc_in..., __acc_in_out...);
342
343 func(index_t(item.get_linear_id()), __acc_in..., __acc_in_out...);
344 });
345 };
346 });
347 }
348 } // namespace details
349
/// @brief Submit a kernel to a SYCL queue (u32 index variant).
/// `func` is invoked for each index i in [0, n) as
///   func(i, <read accessors of in>..., <write accessors of in_out>...)
/// and the submission event is then handed to in/in_out.complete_event_state.
/// Throws std::runtime_error (from the details implementation) when n == 0.
/// NOTE(review): `callsite` is logged here but not forwarded to
/// details::typed_index_kernel_call, which therefore logs its own default SourceLocation{} —
/// confirm this is intended.
513 template<class RefIn, class RefOut, class Functor>
516 RefIn in,
517 RefOut in_out,
518 u32 n,
519 Functor &&func,
520 SourceLocation &&callsite = SourceLocation{}) {
521
522 __shamrock_log_callsite(callsite);
523
524 details::typed_index_kernel_call<u32, RefIn, RefOut>(
525 q, in, in_out, n, std::forward<Functor>(func));
526 }
527
/// @brief u64 indexed variant of kernel_call: func receives a u64 index.
/// Throws std::runtime_error (from the details implementation) when n == 0.
/// NOTE(review): details::typed_index_kernel_call builds its kernel generator with a `u32 n`
/// parameter, so this u64 n is implicitly converted to u32 there. Sizes >= 2^32 would be
/// truncated; verify before relying on this variant for such sizes.
/// NOTE(review): `callsite` is not forwarded to the details call (see kernel_call).
529 template<class RefIn, class RefOut, class Functor>
532 RefIn in,
533 RefOut in_out,
534 u64 n,
535 Functor &&func,
536 SourceLocation &&callsite = SourceLocation{}) {
537
538 __shamrock_log_callsite(callsite);
539
540 details::typed_index_kernel_call<u64, RefIn, RefOut>(
541 q, in, in_out, n, std::forward<Functor>(func));
542 }
543
544 // Version where one supplies a kernel generator instead of a per-index functor:
// kernel_gen(n, <read accessors of in>..., <write accessors of in_out>...) is called once and
// must return a command-group function of the form [=](sycl::handler &cgh) { ... }, which is
// then submitted on q. Capture by value, as typed_index_kernel_call does: the returned lambda
// is invoked after kernel_gen has returned, so a [&] capture of kernel_gen's own parameters
// (n, accessors) may dangle.
// Throws std::runtime_error (from the details implementation) when n == 0.
545 template<class RefIn, class RefOut, class Functor>
546 void kernel_call_hndl(
548 RefIn in,
549 RefOut in_out,
550 u32 n,
551 Functor &&kernel_gen,
552 SourceLocation &&callsite = SourceLocation{}) {
553
554 __shamrock_log_callsite(callsite);
555
556 details::typed_index_kernel_call_lambda<u32, RefIn, RefOut>(
557 q, in, in_out, n, std::forward<Functor>(kernel_gen));
558 }
559
/// @brief u64 indexed variant of kernel_call_hndl.
/// kernel_gen receives n as a u64, followed by the read accessors of `in` and the write
/// accessors of `in_out`. It must return a command-group function, preferably capturing
/// by value ([=]). Throws std::runtime_error (from the details implementation) when n == 0.
561 template<class RefIn, class RefOut, class Functor>
564 RefIn in,
565 RefOut in_out,
566 u64 n,
567 Functor &&kernel_gen,
568 SourceLocation &&callsite = SourceLocation{}) {
569
570 __shamrock_log_callsite(callsite);
571
572 details::typed_index_kernel_call_lambda<u64, RefIn, RefOut>(
573 q, in, in_out, n, std::forward<Functor>(kernel_gen));
574 }
575
576} // namespace sham
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
void typed_index_kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, index_t n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Internal implementation of kernel_call and kernel_call_u64.
const T * read_access_optional(shambase::opt_ref< sham::DeviceBuffer< T > > buffer, sham::EventList &depends_list)
Get a pointer to the data of an optional device buffer, for reading.
void complete_state_optional(sycl::event e, shambase::opt_ref< T > buffer)
Complete the event state of an optional device buffer.
void typed_index_kernel_call_lambda(sham::DeviceQueue &q, RefIn in, RefOut in_out, index_t n, Functor &&kernel_gen, SourceLocation &&callsite=SourceLocation{})
Internal implementation of kernel_call_hndl and kernel_call_hndl_u64 (also used by typed_index_kernel_call).
T * write_access_optional(shambase::opt_ref< sham::DeviceBuffer< T > > buffer, sham::EventList &depends_list)
Get a pointer to the data of an optional device buffer, for writing.
Namespace for internal details of the logs module.
Namespace for the backends; it is named only `sham` because `shambackends` is too long to write.
shambase::opt_ref< T > to_opt_ref(T &t)
Converts a reference to a given object into an optional reference wrapper.
auto empty_buf_ref()
Returns an empty optional containing a reference to a sham::DeviceBuffer<T>.
void kernel_call_u64(sham::DeviceQueue &q, RefIn in, RefOut in_out, u64 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
u64 indexed variant of kernel_call
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
void kernel_call_hndl_u64(sham::DeviceQueue &q, RefIn in, RefOut in_out, u64 n, Functor &&kernel_gen, SourceLocation &&callsite=SourceLocation{})
u64 indexed variant of kernel_call_hndl
namespace for basic c++ utilities
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
std::optional< std::reference_wrapper< T > > opt_ref
Optional reference wrapper.
Definition optional.hpp:26
#define __shamrock_stack_entry_with_callsite(callsite)
Macro to create a stack entry.
#define __shamrock_stack_entry()
Macro to create a stack entry.
#define __shamrock_log_callsite(callsite)
Macro to create a stack entry from a given location. Can be used only on SourceLocation &&.
Provides information about the source location.
A variant of MultiRef for optional buffers.
void complete_event_state(sycl::event e)
Complete the event state of the buffers.
auto get_write_access(sham::EventList &depends_list)
Get a tuple of pointers to the data of the buffers, for writing.
std::tuple< shambase::opt_ref< Targ >... > storage_t
A tuple of optional references to the buffers.
MultiRefOpt(shambase::opt_ref< Targ >... arg)
Constructor from a tuple of optional references to the buffers.
auto get_read_access(sham::EventList &depends_list)
Get a tuple of pointers to the data of the buffers, for reading.
storage_t storage
The tuple of optional references to the buffers.
A class that references multiple buffers or similar objects.
storage_t storage
A tuple of references to the buffers.
auto get_write_access(sham::EventList &depends_list)
std::tuple< Targ &... > storage_t
A tuple of references to the buffers.
MultiRef(Targ &...arg)
Constructor.
auto get_read_access(sham::EventList &depends_list)
void complete_event_state(sycl::event e)
Internal utility for the MultiRef template deduction guide.
T type
The mapped type.