Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
sycl_mpi_interop.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
16#include "shambase/memory.hpp"
17#include "shamcomm/wrapper.hpp"
20
21namespace impl::copy_to_host {
22
23 // using namespace mpi_sycl_interop;
24
25 namespace send {
26 template<class T>
27 T *init(const std::unique_ptr<sycl::buffer<T>> &buf, u32 comm_sz) {
28
29 using namespace shamsys::instance;
30
31 T *comm_ptr = sycl::malloc_host<T>(comm_sz, get_compute_queue());
32 get_compute_queue().wait();
33 shamlog_debug_sycl_ln(
34 "PatchDataField MPI Comm",
35 "sycl::malloc_host",
36 comm_sz,
37 "->",
38 reinterpret_cast<void *>(comm_ptr));
39
40 if (comm_sz > 0) {
41 shamlog_debug_sycl_ln("PatchDataField MPI Comm", "copy buffer -> USM");
42
43 {
44 sycl::host_accessor acc{shambase::get_check_ref(buf), sycl::read_only};
45
46 const T *src = &(acc[0]);
47 T *dest = comm_ptr;
48
49 std::memcpy(dest, src, sizeof(T) * comm_sz);
50 }
51
52 } else {
53 shamlog_debug_sycl_ln(
54 "PatchDataField MPI Comm", "copy buffer -> USM (skipped size=0)");
55 }
56
57 return comm_ptr;
58 }
59
60#define X(_t) template _t *init<_t>(const std::unique_ptr<sycl::buffer<_t>> &buf, u32 comm_sz);
61 XMAC_SYCLMPI_TYPE_ENABLED
62#undef X
63
64 template<class T>
65 void finalize(T *comm_ptr) {
66
67 using namespace shamsys::instance;
68
69 shamlog_debug_sycl_ln(
70 "PatchDataField MPI Comm", "sycl::free", reinterpret_cast<void *>(comm_ptr));
71
72 sycl::free(comm_ptr, get_compute_queue());
73 }
74
75#define X(_t) template void finalize(_t *comm_ptr);
76 XMAC_SYCLMPI_TYPE_ENABLED
77#undef X
78 } // namespace send
79
80 namespace recv {
81 template<class T>
82 T *init(u32 comm_sz) {
83
84 using namespace shamsys::instance;
85
86 T *comm_ptr = sycl::malloc_host<T>(comm_sz, shamsys::instance::get_compute_queue());
87
88 shamlog_debug_sycl_ln("PatchDataField MPI Comm", "sycl::malloc_host", comm_sz);
89
90 return comm_ptr;
91 };
92
93#define X(_t) template _t *init(u32 comm_sz);
94 XMAC_SYCLMPI_TYPE_ENABLED
95#undef X
96 template<class T>
97 void finalize(const std::unique_ptr<sycl::buffer<T>> &buf, T *comm_ptr, u32 comm_sz) {
98
99 if (comm_sz > 0) {
100 shamlog_debug_sycl_ln("PatchDataField MPI Comm", "copy USM -> buffer");
101
102 {
103 sycl::host_accessor acc{
104 shambase::get_check_ref(buf), sycl::write_only, sycl::no_init};
105
106 const T *src = comm_ptr;
107 T *dest = &(acc[0]);
108
109 std::memcpy(dest, src, sizeof(T) * comm_sz);
110 }
111
112 } else {
113 shamlog_debug_sycl_ln(
114 "PatchDataField MPI Comm", "copy USM -> buffer (skipped size=0)");
115 }
116
117 shamlog_debug_sycl_ln(
118 "PatchDataField MPI Comm", "sycl::free", reinterpret_cast<void *>(comm_ptr));
119
120 sycl::free(comm_ptr, shamsys::instance::get_compute_queue());
121 }
122
123#define X(_t) \
124 template void finalize(const std::unique_ptr<sycl::buffer<_t>> &buf, _t *comm_ptr, u32 comm_sz);
125 XMAC_SYCLMPI_TYPE_ENABLED
126#undef X
127 } // namespace recv
128
129} // namespace impl::copy_to_host
130
131namespace impl::directgpu {
132
133 using namespace mpi_sycl_interop;
134
135 namespace send {
136 template<class T>
137 T *init(const std::unique_ptr<sycl::buffer<T>> &buf, u32 comm_sz) {
138
139 T *comm_ptr = sycl::malloc_device<T>(comm_sz, shamsys::instance::get_compute_queue());
140 shamlog_debug_sycl_ln(
141 "PatchDataField MPI Comm", "sycl::malloc_device", comm_sz, "->", comm_ptr);
142
143 if (comm_sz > 0) {
144 shamlog_debug_sycl_ln("PatchDataField MPI Comm", "copy buffer -> USM");
145
146 auto ker_copy
147 = shamsys::instance::get_compute_queue().submit([&](sycl::handler &cgh) {
148 sycl::accessor acc{*buf, cgh, sycl::read_only};
149
150 T *ptr = comm_ptr;
151
152 cgh.parallel_for(sycl::range<1>{comm_sz}, [=](sycl::item<1> item) {
153 ptr[item.get_linear_id()] = acc[item];
154 });
155 });
156
157 ker_copy.wait();
158 } else {
159 shamlog_debug_sycl_ln(
160 "PatchDataField MPI Comm", "copy buffer -> USM (skipped size=0)");
161 }
162
163 return comm_ptr;
164 }
165
166#define X(_t) template _t *init<_t>(const std::unique_ptr<sycl::buffer<_t>> &buf, u32 comm_sz);
167 XMAC_SYCLMPI_TYPE_ENABLED
168#undef X
169
170 template<class T>
171 void finalize(T *comm_ptr) {
172 shamlog_debug_sycl_ln("PatchDataField MPI Comm", "sycl::free", comm_ptr);
173
174 sycl::free(comm_ptr, shamsys::instance::get_compute_queue());
175 }
176
177#define X(_t) template void finalize(_t *comm_ptr);
178 XMAC_SYCLMPI_TYPE_ENABLED
179#undef X
180 } // namespace send
181
182 namespace recv {
183 template<class T>
184 T *init(u32 comm_sz) {
185 T *comm_ptr = sycl::malloc_device<T>(comm_sz, shamsys::instance::get_compute_queue());
186
187 shamlog_debug_sycl_ln("PatchDataField MPI Comm", "sycl::malloc_device", comm_sz);
188
189 return comm_ptr;
190 };
191
192#define X(_t) template _t *init(u32 comm_sz);
193 XMAC_SYCLMPI_TYPE_ENABLED
194#undef X
195 template<class T>
196 void finalize(const std::unique_ptr<sycl::buffer<T>> &buf, T *comm_ptr, u32 comm_sz) {
197
198 if (comm_sz > 0) {
199 shamlog_debug_sycl_ln("PatchDataField MPI Comm", "copy USM -> buffer");
200
201 auto ker_copy
202 = shamsys::instance::get_compute_queue().submit([&](sycl::handler &cgh) {
203 sycl::accessor acc{*buf, cgh, sycl::write_only};
204
205 T *ptr = comm_ptr;
206
207 cgh.parallel_for(sycl::range<1>{comm_sz}, [=](sycl::item<1> item) {
208 acc[item] = ptr[item.get_linear_id()];
209 });
210 });
211
212 ker_copy.wait();
213 } else {
214 shamlog_debug_sycl_ln(
215 "PatchDataField MPI Comm", "copy USM -> buffer (skipped size=0)");
216 }
217
218 shamlog_debug_sycl_ln("PatchDataField MPI Comm", "sycl::free", comm_ptr);
219
220 sycl::free(comm_ptr, shamsys::instance::get_compute_queue());
221 }
222#define X(_t) \
223 template void finalize(const std::unique_ptr<sycl::buffer<_t>> &buf, _t *comm_ptr, u32 comm_sz);
224 XMAC_SYCLMPI_TYPE_ENABLED
225#undef X
226
227 } // namespace recv
228
229} // namespace impl::directgpu
230
231namespace mpi_sycl_interop {
232
// Namespace-scope communication-mode variable, initialised to CopyToHost (copy data to the
// host and then perform the call).
// NOTE(review): nothing in this file reads it; presumably callers consult it when choosing
// the comm_mode passed to BufferMpiRequest — confirm.
233 comm_type current_mode = CopyToHost;
234
// Constructor of BufferMpiRequest<T>: records the communication parameters and prepares
// the staging pointer `comm_ptr` for the requested (mode, operation) pair.
//
// Parameters:
//   sycl_buf  - buffer handle; forwarded to the send::init helpers and stored in the
//               `sycl_buf` member
//   comm_mode - CopyToHost or DirectGPU
//   comm_op   - Send or Recv_Probe
//   comm_sz   - element count forwarded to the init helpers
//
// NOTE(review): the declarator line (Doxygen line 236) is absent from this listing;
// presumably `BufferMpiRequest<T>::BufferMpiRequest(` — confirm against the real source.
235 template<class T>
237 std::unique_ptr<sycl::buffer<T>> &sycl_buf,
238 comm_type comm_mode,
239 op_type comm_op,
240 u32 comm_sz)
241 : comm_mode(comm_mode), comm_op(comm_op), comm_sz(comm_sz), sycl_buf(sycl_buf) {
242
243 shamlog_debug_mpi_ln(
244 "PatchDataField MPI Comm",
245 "starting mpi sycl comm ",
246 comm_sz,
247 int(comm_op),
248 int(comm_mode));
249
// Dispatch on (mode, op): send variants copy the buffer content into the staging
// allocation, receive variants only allocate it.
250 if (comm_mode == CopyToHost && comm_op == Send) {
251
252 comm_ptr = impl::copy_to_host::send::init<T>(sycl_buf, comm_sz);
253
254 } else if (comm_mode == CopyToHost && comm_op == Recv_Probe) {
255
256 comm_ptr = impl::copy_to_host::recv::init<T>(comm_sz);
257
258 } else if (comm_mode == DirectGPU && comm_op == Send) {
259
260 comm_ptr = impl::directgpu::send::init<T>(sycl_buf, comm_sz);
261
262 } else if (comm_mode == DirectGPU && comm_op == Recv_Probe) {
263
264 comm_ptr = impl::directgpu::recv::init<T>(comm_sz);
265
266 } else {
// Unsupported combination: only an error is logged and this body does not assign
// comm_ptr. NOTE(review): check whether the member has a default initialiser in the
// header; otherwise it is left indeterminate on this path.
267 logger::err_ln(
268 "PatchDataField MPI Comm",
269 "communication mode & op combination not implemented :",
270 int(comm_mode),
271 int(comm_op));
272 }
273 }
274
// Completes a communication started by the constructor: resets `sycl_buf`, then calls the
// finalize helper matching (comm_mode, comm_op). Send helpers only free `comm_ptr`;
// receive helpers copy `comm_sz` elements from `comm_ptr` into `sycl_buf` and then free it.
//
// NOTE(review): the declarator line (Doxygen line 276) is absent from this listing;
// presumably `void BufferMpiRequest<T>::finalize() {` — confirm against the real source.
275 template<class T>
277
278 shamlog_debug_mpi_ln(
279 "PatchDataField MPI Comm",
280 "finalizing mpi sycl comm ",
281 comm_sz,
282 int(comm_op),
283 int(comm_mode));
284
// A fresh buffer of comm_sz elements replaces whatever sycl_buf held, for every mode/op.
// NOTE(review): this also runs for Send, discarding the previous buffer object — confirm
// that is intended. Also confirm comm_sz == 0 is a valid buffer size for the SYCL
// implementation in use.
285 sycl_buf = std::make_unique<sycl::buffer<T>>(comm_sz);
286
287 if (comm_mode == CopyToHost && comm_op == Send) {
288
289 impl::copy_to_host::send::finalize<T>(comm_ptr);
290
291 } else if (comm_mode == CopyToHost && comm_op == Recv_Probe) {
292
293 impl::copy_to_host::recv::finalize<T>(sycl_buf, comm_ptr, comm_sz);
294
295 } else if (comm_mode == DirectGPU && comm_op == Send) {
296
297 impl::directgpu::send::finalize<T>(comm_ptr);
298
299 } else if (comm_mode == DirectGPU && comm_op == Recv_Probe) {
300
301 impl::directgpu::recv::finalize<T>(sycl_buf, comm_ptr, comm_sz);
302
303 } else {
// Unsupported combination: only an error is logged; comm_ptr is not freed here.
304 logger::err_ln(
305 "PatchDataField MPI Comm",
306 "communication mode & op combination not implemented :",
307 int(comm_mode),
308 int(comm_op));
309 }
310 }
311
// Explicitly instantiate BufferMpiRequest for each type listed by
// XMAC_SYCLMPI_TYPE_ENABLED (the list macro expands X once per type).
312#define X(a) template struct BufferMpiRequest<a>;
313 XMAC_SYCLMPI_TYPE_ENABLED
314#undef X
315
316} // namespace mpi_sycl_interop
sycl::queue & get_compute_queue(u32 id=0)
std::uint32_t u32
32 bit unsigned integer
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
@ CopyToHost
copy data to the host and then perform the call
header file to manage sycl