Shamrock 2025.10.0 — Astrophysical Code
sycl_mpi_interop.hpp (source listing)
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

10#pragma once
11
23#include "log.hpp"
25#include "shamcomm/wrapper.hpp"
28
/// X-macro list of every type for which the SYCL/MPI interop layer is
/// enabled. Expand it by defining X(arg) beforehand and #undef-ing X
/// afterwards (see the uses inside BufferMpiRequest in this header).
#define XMAC_SYCLMPI_TYPE_ENABLED                                                                  \
    X(f32)                                                                                         \
    X(f32_2)                                                                                       \
    X(f32_3)                                                                                       \
    X(f32_4)                                                                                       \
    X(f32_8)                                                                                       \
    X(f32_16)                                                                                      \
    X(f64)                                                                                         \
    X(f64_2)                                                                                       \
    X(f64_3)                                                                                       \
    X(f64_4)                                                                                       \
    X(f64_8)                                                                                       \
    X(f64_16)                                                                                      \
    X(u8)                                                                                          \
    X(u32)                                                                                         \
    X(u32_3)                                                                                       \
    X(u16_3)                                                                                       \
    X(u64)                                                                                         \
    X(u64_3)                                                                                       \
    X(i64_3)                                                                                       \
    X(i64)
50
51namespace mpi_sycl_interop {
52
53 enum comm_type { CopyToHost, DirectGPU };
54 enum op_type { Send, Recv_Probe };
55
56 extern comm_type current_mode;
57
58 template<class T>
60
61 static constexpr bool is_in_type_list =
62#define X(args) std::is_same<T, args>::value ||
63 XMAC_SYCLMPI_TYPE_ENABLED false
64#undef X
65 ;
66
67 static_assert(
68 is_in_type_list,
69 "BufferMpiRequest must be one of those types : "
70
71#define X(args) #args " "
72 XMAC_SYCLMPI_TYPE_ENABLED
73#undef X
74 );
75
76 MPI_Request mpi_rq;
77 comm_type comm_mode;
78 op_type comm_op;
79 T *comm_ptr;
80 u32 comm_sz;
81 std::unique_ptr<sycl::buffer<T>> &sycl_buf;
82
84 std::unique_ptr<sycl::buffer<T>> &sycl_buf,
85 comm_type comm_mode,
86 op_type comm_op,
87 u32 comm_sz);
88
89 inline T *get_mpi_ptr() { return comm_ptr; }
90
91 void finalize();
92 };
93
94 template<class T>
95 inline u64 isend(
96 std::unique_ptr<sycl::buffer<T>> &p,
97 const u32 &size_comm,
98 std::vector<BufferMpiRequest<T>> &rq_lst,
99 i32 rank_dest,
100 i32 tag,
101 MPI_Comm comm) {
102
103 rq_lst.push_back(BufferMpiRequest<T>(p, current_mode, Send, size_comm));
104
105 u32 rq_index = rq_lst.size() - 1;
106
107 auto &rq = rq_lst[rq_index];
108
109 shamcomm::mpi::Isend(
110 rq.get_mpi_ptr(),
111 size_comm,
112 get_mpi_type<T>(),
113 rank_dest,
114 tag,
115 comm,
116 &(rq_lst[rq_index].mpi_rq));
117
118 return sizeof(T) * size_comm;
119 }
120
121 template<class T>
122 inline u64 irecv(
123 std::unique_ptr<sycl::buffer<T>> &p,
124 const u32 &size_comm,
125 std::vector<BufferMpiRequest<T>> &rq_lst,
126 i32 rank_source,
127 i32 tag,
128 MPI_Comm comm) {
129
130 rq_lst.push_back(BufferMpiRequest<T>(p, current_mode, Recv_Probe, size_comm));
131
132 u32 rq_index = rq_lst.size() - 1;
133
134 auto &rq = rq_lst[rq_index];
135
136 shamcomm::mpi::Irecv(
137 rq.get_mpi_ptr(),
138 size_comm,
139 get_mpi_type<T>(),
140 rank_source,
141 tag,
142 comm,
143 &(rq_lst[rq_index].mpi_rq));
144
145 return sizeof(T) * size_comm;
146 }
147
148 template<class T>
149 inline u64 irecv_probe(
150 std::unique_ptr<sycl::buffer<T>> &p,
151 std::vector<BufferMpiRequest<T>> &rq_lst,
152 i32 rank_source,
153 i32 tag,
154 MPI_Comm comm) {
155 MPI_Status st;
156 i32 cnt;
157 shamcomm::mpi::Probe(rank_source, tag, comm, &st);
158 shamcomm::mpi::Get_count(&st, get_mpi_type<T>(), &cnt);
159
160 u32 len = cnt;
161
162 return irecv(p, len, rq_lst, rank_source, tag, comm);
163 }
164
165 template<class T>
166 inline std::vector<MPI_Request> get_rqs(std::vector<BufferMpiRequest<T>> &rq_lst) {
167 std::vector<MPI_Request> addrs;
168
169 for (auto a : rq_lst) {
170 addrs.push_back(a.mpi_rq);
171 }
172
173 return addrs;
174 }
175
176 template<class T>
177 inline void waitall(std::vector<BufferMpiRequest<T>> &rq_lst) {
178 std::vector<MPI_Request> addrs;
179
180 for (auto a : rq_lst) {
181 addrs.push_back(a.mpi_rq);
182 }
183
184 std::vector<MPI_Status> st_lst(addrs.size());
185 shamcomm::mpi::Waitall(addrs.size(), addrs.data(), st_lst.data());
186
187 for (auto a : rq_lst) {
188 a.finalize();
189 }
190 }
191
192 template<class T>
193 inline void file_write(MPI_File fh, std::unique_ptr<sycl::buffer<T>> &p, const u32 &size_comm) {
194 MPI_Status st;
195
196 BufferMpiRequest<T> rq(p, current_mode, Send, size_comm);
197
198 shamcomm::mpi::File_write(fh, rq.get_mpi_ptr(), size_comm, get_mpi_type<T>(), &st);
199
200 rq.finalize();
201 }
202
203} // namespace mpi_sycl_interop
204
205namespace impl::copy_to_host {
206
207 namespace send {
208 template<class T>
209 T *init(const std::unique_ptr<sycl::buffer<T>> &buf, u32 comm_sz);
210
211 template<class T>
212 void finalize(T *comm_ptr);
213 } // namespace send
214
215 namespace recv {
216 template<class T>
217 T *init(u32 comm_sz);
218
219 template<class T>
220 void finalize(const std::unique_ptr<sycl::buffer<T>> &buf, T *comm_ptr, u32 comm_sz);
221 } // namespace recv
222
223} // namespace impl::copy_to_host
224
225namespace impl::directgpu {
226
227 namespace send {
228 template<class T>
229 T *init(const std::unique_ptr<sycl::buffer<T>> &buf, u32 comm_sz);
230
231 template<class T>
232 void finalize(T *comm_ptr);
233 } // namespace send
234
235 namespace recv {
236 template<class T>
237 T *init(u32 comm_sz);
238
239 template<class T>
240 void finalize(const std::unique_ptr<sycl::buffer<T>> &buf, T *comm_ptr, u32 comm_sz);
241 } // namespace recv
242
243} // namespace impl::directgpu
This header performs the MPI include and wraps the MPI calls.
Header file describing a Node Instance.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit signed integer
@ CopyToHost
copy data to the host and then perform the call
@ DirectGPU
copy data straight from the GPU
void Get_count(const MPI_Status *status, MPI_Datatype datatype, int *count)
MPI wrapper for MPI_Get_count.
Definition wrapper.cpp:222
void Probe(int source, int tag, MPI_Comm comm, MPI_Status *status)
MPI wrapper for MPI_Probe.
Definition wrapper.cpp:201
void Waitall(int count, MPI_Request array_of_requests[], MPI_Status *array_of_statuses)
MPI wrapper for MPI_Waitall.
Definition wrapper.cpp:187
void File_write(MPI_File fh, const void *buf, int count, MPI_Datatype datatype, MPI_Status *status)
MPI wrapper for MPI_File_write.
Definition wrapper.cpp:264