Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
kernel_call_distrib.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
21
22namespace sham {
23
32 template<class... Targ>
33 struct DDMultiRef {
35 using storage_t = std::tuple<Targ &...>;
36
39
41 DDMultiRef(Targ &...arg) : storage(arg...) {}
42
51 auto get(u64 id) {
52 shamlog_debug_ln(
53 "kern call",
54 "called DDMultiRef.get, id =",
55 id,
57 return std::apply(
58 [id](auto &...args) {
59 return sham::MultiRef{args.get(id)...};
60 },
61 storage);
62 }
63 };
64
78 template<class index_t, class RefIn, class RefOut, class Functor>
80 sham::DeviceScheduler_ptr dev_sched,
81 RefIn in,
82 RefOut in_out,
83 const shambase::DistributedData<index_t> &thread_counts,
84 Functor &&func) {
85
86 auto mrefs_in
87 = thread_counts.template map<decltype(in.get(0))>([&](u64 id, const index_t &n) {
88 shamlog_debug_ln("kern call", "build multi ref in for patch", id);
89 return in.get(id);
90 });
91
92 auto mrefs_in_out
93 = thread_counts.template map<decltype(in_out.get(0))>([&](u64 id, const index_t &n) {
94 shamlog_debug_ln("kern call", "build multi ref in_out for patch", id);
95 return in_out.get(id);
96 });
97
98 thread_counts.for_each([&](u64 id, const index_t &n) {
99 shamlog_debug_ln(
100 "kern call", "calling sham::kernel_call on patch", id, " thread count", n);
102 dev_sched->get_queue(),
103 mrefs_in.get(id),
104 mrefs_in_out.get(id),
105 n,
106 std::forward<Functor>(func));
107 });
108 }
109
110 // version where one supplies a kernel generator in the form of [&](sycl::handler &cgh) { ... }
111 template<class index_t, class RefIn, class RefOut, class Functor>
112 inline void distributed_data_kernel_call_hndl(
113 sham::DeviceScheduler_ptr dev_sched,
114 RefIn in,
115 RefOut in_out,
116 const shambase::DistributedData<index_t> &thread_counts,
117 Functor &&kernel_gen) {
118
119 auto mrefs_in
120 = thread_counts.template map<decltype(in.get(0))>([&](u64 id, const index_t &n) {
121 shamlog_debug_ln("kern call", "build multi ref in for patch", id);
122 return in.get(id);
123 });
124
125 auto mrefs_in_out
126 = thread_counts.template map<decltype(in_out.get(0))>([&](u64 id, const index_t &n) {
127 shamlog_debug_ln("kern call", "build multi ref in_out for patch", id);
128 return in_out.get(id);
129 });
130
131 thread_counts.for_each([&](u64 id, const index_t &n) {
132 shamlog_debug_ln(
133 "kern call", "calling sham::kernel_call_hndl on patch", id, " thread count", n);
134
135 sham::kernel_call_hndl(
136 dev_sched->get_queue(),
137 mrefs_in.get(id),
138 mrefs_in_out.get(id),
139 n,
140 std::forward<Functor>(kernel_gen));
141 });
142 }
143
144} // namespace sham
std::uint64_t u64
64 bit unsigned integer
Represents a collection of objects distributed across patches identified by a u64 id.
void for_each(std::function< void(u64, T &)> &&f)
Applies a function to each object in the collection.
Namespace for the backends; it is named just sham because shambackends is too long to write.
void distributed_data_kernel_call(sham::DeviceScheduler_ptr dev_sched, RefIn in, RefOut in_out, const shambase::DistributedData< index_t > &thread_counts, Functor &&func)
A variant of sham::kernel_call for distributed data.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
Provides information about the source location.
std::string format_one_line_func() const
Formats the location as a one-liner that includes the function name.
A variant of sham::MultiRef for distributed data.
auto get(u64 id)
Get a MultiRef at a given id.
DDMultiRef(Targ &...arg)
Constructor.
std::tuple< Targ &... > storage_t
A tuple of references to the buffers.
storage_t storage
A tuple of references to the buffers.
A class that references multiple buffers or similar objects.