Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
wrapper.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
19#include "shambase/time.hpp"
22#include "shamcomm/wrapper.hpp"
23#include <unordered_map>
24#include <array>
25
26namespace {
27
28 std::unordered_map<std::string, f64> mpi_timers;
29
30} // namespace
31
32namespace shamcomm::mpi {
33 void register_time(std::string timername, f64 time) {
34 mpi_timers[timername] += time;
35 mpi_timers["total"] += time;
36
37 if (shambase::profiling::is_profiling_enabled()) {
38 auto wtime = shambase::details::get_wtime();
39 shambase::profiling::register_counter_val(timername, wtime, mpi_timers[timername]);
40 shambase::profiling::register_counter_val("total MPi time", wtime, mpi_timers["total"]);
41 }
42 }
43
44 f64 get_timer(std::string timername) { return mpi_timers[timername]; }
45
46 const std::unordered_map<std::string, f64> &get_timers() { return mpi_timers; }
47
48 std::vector<std::string> possible_keys{
49 "total", "MPI_Isend", "MPI_Irecv",
50 "MPI_Allreduce", "MPI_Allgather", "MPI_Allgatherv",
51 "MPI_Exscan", "MPI_Wait", "MPI_Waitall",
52 "MPI_Barrier", "MPI_Probe", "MPI_Recv",
53 "MPI_Get_count", "MPI_Send", "MPI_File_set_view",
54 "MPI_Type_size", "MPI_File_write_all", "MPI_File_write",
55 "MPI_File_read", "MPI_File_write_at", "MPI_File_read_at",
56 "MPI_File_close", "MPI_File_open", "MPI_Test",
57 "MPI_Gather", "MPI_Gatherv",
58 };
59
60 const std::vector<std::string> &get_possible_keys() { return possible_keys; }
61
62} // namespace shamcomm::mpi
63
64namespace {
65
66 template<class Func>
67 inline void wrap_profiling(std::string timername, Func &&f) {
68 f64 tstart;
69 tstart = shambase::details::get_wtime();
70 f();
71 shamcomm::mpi::register_time(timername, shambase::details::get_wtime() - tstart);
72 }
73
74} // namespace
75
76namespace shamcomm::mpi {
77
78 void check_tag_value(i32 tag) {
79 if (tag > mpi_max_tag_value()) {
81 "mpi_max_tag_value ({}) exceeded with tag {}", mpi_max_tag_value(), tag));
82 }
83 }
84
85 void Isend(
86 const void *buf,
87 int count,
88 MPI_Datatype datatype,
89 int dest,
90 int tag,
91 MPI_Comm comm,
92 MPI_Request *request) {
93 StackEntry stack_loc{};
94
95 check_tag_value(tag);
96
97 wrap_profiling("MPI_Isend", [&]() {
98 MPICHECK(MPI_Isend(buf, count, datatype, dest, tag, comm, request));
99 });
100 }
101
102 void Irecv(
103 void *buf,
104 int count,
105 MPI_Datatype datatype,
106 int source,
107 int tag,
108 MPI_Comm comm,
109 MPI_Request *request) {
110 StackEntry stack_loc{};
111
112 check_tag_value(tag);
113
114 wrap_profiling("MPI_Irecv", [&]() {
115 MPICHECK(MPI_Irecv(buf, count, datatype, source, tag, comm, request));
116 });
117 }
118
120 const void *sendbuf,
121 void *recvbuf,
122 int count,
123 MPI_Datatype datatype,
124 MPI_Op op,
125 MPI_Comm comm) {
126 StackEntry stack_loc{};
127
128 wrap_profiling("MPI_Allreduce", [&]() {
129 MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, datatype, op, comm));
130 });
131 }
132
134 const void *sendbuf,
135 int sendcount,
136 MPI_Datatype sendtype,
137 void *recvbuf,
138 int recvcount,
139 MPI_Datatype recvtype,
140 MPI_Comm comm) {
141 StackEntry stack_loc{};
142
143 wrap_profiling("MPI_Allgather", [&]() {
144 MPICHECK(
145 MPI_Allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm));
146 });
147 }
148
150 const void *sendbuf,
151 int sendcount,
152 MPI_Datatype sendtype,
153 void *recvbuf,
154 const int recvcounts[],
155 const int displs[],
156 MPI_Datatype recvtype,
157 MPI_Comm comm) {
158 StackEntry stack_loc{};
159
160 wrap_profiling("MPI_Allgatherv", [&]() {
161 MPICHECK(MPI_Allgatherv(
162 sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm));
163 });
164 }
165
166 void Exscan(
167 const void *sendbuf,
168 void *recvbuf,
169 int count,
170 MPI_Datatype datatype,
171 MPI_Op op,
172 MPI_Comm comm) {
173 StackEntry stack_loc{};
174
175 wrap_profiling("MPI_Exscan", [&]() {
176 MPICHECK(MPI_Exscan(sendbuf, recvbuf, count, datatype, op, comm));
177 });
178 }
179
180 void Wait(MPI_Request *request, MPI_Status *status) {
181 StackEntry stack_loc{};
182 wrap_profiling("MPI_Wait", [&]() {
183 MPICHECK(MPI_Wait(request, status));
184 });
185 }
186
187 void Waitall(int count, MPI_Request array_of_requests[], MPI_Status *array_of_statuses) {
188 StackEntry stack_loc{};
189 wrap_profiling("MPI_Waitall", [&]() {
190 MPICHECK(MPI_Waitall(count, array_of_requests, array_of_statuses));
191 });
192 }
193
194 void Barrier(MPI_Comm comm) {
195 StackEntry stack_loc{};
196 wrap_profiling("MPI_Barrier", [&]() {
197 MPICHECK(MPI_Barrier(comm));
198 });
199 }
200
201 void Probe(int source, int tag, MPI_Comm comm, MPI_Status *status) {
202 StackEntry stack_loc{};
203 wrap_profiling("MPI_Probe", [&]() {
204 MPICHECK(MPI_Probe(source, tag, comm, status));
205 });
206 }
207
208 void Recv(
209 void *buf,
210 int count,
211 MPI_Datatype datatype,
212 int source,
213 int tag,
214 MPI_Comm comm,
215 MPI_Status *status) {
216 StackEntry stack_loc{};
217 wrap_profiling("MPI_Recv", [&]() {
218 MPICHECK(MPI_Recv(buf, count, datatype, source, tag, comm, status));
219 });
220 }
221
222 void Get_count(const MPI_Status *status, MPI_Datatype datatype, int *count) {
223 StackEntry stack_loc{};
224 wrap_profiling("MPI_Get_count", [&]() {
225 MPICHECK(MPI_Get_count(status, datatype, count));
226 });
227 }
228
229 void Send(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm) {
230 StackEntry stack_loc{};
231 wrap_profiling("MPI_Send", [&]() {
232 MPICHECK(MPI_Send(buf, count, datatype, dest, tag, comm));
233 });
234 }
235
237 MPI_File fh,
238 MPI_Offset disp,
239 MPI_Datatype etype,
240 MPI_Datatype filetype,
241 const char *datarep,
242 MPI_Info info) {
243 StackEntry stack_loc{};
244 wrap_profiling("MPI_File_set_view", [&]() {
245 MPICHECK(MPI_File_set_view(fh, disp, etype, filetype, datarep, info));
246 });
247 }
248
249 void Type_size(MPI_Datatype type, int *size) {
250 StackEntry stack_loc{};
251 wrap_profiling("MPI_Type_size", [&]() {
252 MPICHECK(MPI_Type_size(type, size));
253 });
254 }
255
257 MPI_File fh, const void *buf, int count, MPI_Datatype datatype, MPI_Status *status) {
258 StackEntry stack_loc{};
259 wrap_profiling("MPI_File_write_all", [&]() {
260 MPICHECK(MPI_File_write_all(fh, buf, count, datatype, status));
261 });
262 }
263
265 MPI_File fh, const void *buf, int count, MPI_Datatype datatype, MPI_Status *status) {
266 StackEntry stack_loc{};
267 wrap_profiling("MPI_File_write", [&]() {
268 MPICHECK(MPI_File_write(fh, buf, count, datatype, status));
269 });
270 }
271
272 void File_read(MPI_File fh, void *buf, int count, MPI_Datatype datatype, MPI_Status *status) {
273 StackEntry stack_loc{};
274 wrap_profiling("MPI_File_read", [&]() {
275 MPICHECK(MPI_File_read(fh, buf, count, datatype, status));
276 });
277 }
278
280 MPI_File fh,
281 MPI_Offset offset,
282 const void *buf,
283 int count,
284 MPI_Datatype datatype,
285 MPI_Status *status) {
286 StackEntry stack_loc{};
287 wrap_profiling("MPI_File_write_at", [&]() {
288 MPICHECK(MPI_File_write_at(fh, offset, buf, count, datatype, status));
289 });
290 }
291
293 MPI_File fh,
294 MPI_Offset offset,
295 void *buf,
296 int count,
297 MPI_Datatype datatype,
298 MPI_Status *status) {
299 StackEntry stack_loc{};
300 wrap_profiling("MPI_File_read_at", [&]() {
301 MPICHECK(MPI_File_read_at(fh, offset, buf, count, datatype, status));
302 });
303 }
304
305 void File_close(MPI_File *fh) {
306 StackEntry stack_loc{};
307 wrap_profiling("MPI_File_close", [&]() {
308 MPICHECK(MPI_File_close(fh));
309 });
310 }
311
312 void File_open(MPI_Comm comm, const char *filename, int amode, MPI_Info info, MPI_File *fh) {
313 StackEntry stack_loc{};
314 wrap_profiling("MPI_File_open", [&]() {
315 MPICHECK(MPI_File_open(comm, filename, amode, info, fh));
316 });
317 }
318
319 void Test(MPI_Request *request, int *flag, MPI_Status *status) {
320 StackEntry stack_loc{};
321 wrap_profiling("MPI_Test", [&]() {
322 MPICHECK(MPI_Test(request, flag, status));
323 });
324 }
325
326 void Gather(
327 const void *sendbuf,
328 int sendcount,
329 MPI_Datatype sendtype,
330 void *recvbuf,
331 int recvcount,
332 MPI_Datatype recvtype,
333 int root,
334 MPI_Comm comm) {
335 StackEntry stack_loc{};
336 wrap_profiling("MPI_Gather", [&]() {
337 MPICHECK(
338 MPI_Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm));
339 });
340 }
341
343 const void *sendbuf,
344 int sendcount,
345 MPI_Datatype sendtype,
346 void *recvbuf,
347 const int recvcounts[],
348 const int displs[],
349 MPI_Datatype recvtype,
350 int root,
351 MPI_Comm comm) {
352 StackEntry stack_loc{};
353 wrap_profiling("MPI_Gatherv", [&]() {
354 MPICHECK(MPI_Gatherv(
355 sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm));
356 });
357 }
358
359} // namespace shamcomm::mpi
double f64
Alias for double.
std::int32_t i32
32 bit integer
Utility functions for MPI error checking.
#define MPICHECK(mpicall)
Shortcut macro to check MPI return codes.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
i32 mpi_max_tag_value()
Gets the maximum value of the MPI tag.
Definition worldInfo.cpp:36
This file contains the definition for the stacktrace related functionality.
Functions related to the MPI communicator.
void File_set_view(MPI_File fh, MPI_Offset disp, MPI_Datatype etype, MPI_Datatype filetype, const char *datarep, MPI_Info info)
MPI wrapper for MPI_File_set_view.
Definition wrapper.cpp:236
void Get_count(const MPI_Status *status, MPI_Datatype datatype, int *count)
MPI wrapper for MPI_Get_count.
Definition wrapper.cpp:222
void Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Status *status)
MPI wrapper for MPI_Recv.
Definition wrapper.cpp:208
void Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Irecv.
Definition wrapper.cpp:102
void Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
MPI wrapper for MPI_Exscan.
Definition wrapper.cpp:166
void Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Datatype recvtype, int root, MPI_Comm comm)
MPI wrapper for MPI_Gatherv.
Definition wrapper.cpp:342
void Probe(int source, int tag, MPI_Comm comm, MPI_Status *status)
MPI wrapper for MPI_Probe.
Definition wrapper.cpp:201
void Barrier(MPI_Comm comm)
MPI wrapper for MPI_Barrier.
Definition wrapper.cpp:194
void File_close(MPI_File *fh)
MPI wrapper for MPI_File_close.
Definition wrapper.cpp:305
void File_read(MPI_File fh, void *buf, int count, MPI_Datatype datatype, MPI_Status *status)
MPI wrapper for MPI_File_read.
Definition wrapper.cpp:272
void Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Datatype recvtype, MPI_Comm comm)
MPI wrapper for MPI_Allgatherv.
Definition wrapper.cpp:149
const std::vector< std::string > & get_possible_keys()
Return all possible keys for the internal timers.
Definition wrapper.cpp:60
void register_time(std::string timername, f64 time)
Register a timer value.
Definition wrapper.cpp:33
f64 get_timer(std::string timername)
Get a timer value.
Definition wrapper.cpp:44
void File_write_at(MPI_File fh, MPI_Offset offset, const void *buf, int count, MPI_Datatype datatype, MPI_Status *status)
MPI wrapper for MPI_File_write_at.
Definition wrapper.cpp:279
void File_open(MPI_Comm comm, const char *filename, int amode, MPI_Info info, MPI_File *fh)
MPI wrapper for MPI_File_open.
Definition wrapper.cpp:312
void Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
MPI wrapper for MPI_Allreduce.
Definition wrapper.cpp:119
void Waitall(int count, MPI_Request array_of_requests[], MPI_Status *array_of_statuses)
MPI wrapper for MPI_Waitall.
Definition wrapper.cpp:187
void Wait(MPI_Request *request, MPI_Status *status)
MPI wrapper for MPI_Wait.
Definition wrapper.cpp:180
void Type_size(MPI_Datatype type, int *size)
MPI wrapper for MPI_Type_size.
Definition wrapper.cpp:249
const std::unordered_map< std::string, f64 > & get_timers()
Return all internal timers.
Definition wrapper.cpp:46
void Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
MPI wrapper for MPI_Gather.
Definition wrapper.cpp:326
void File_write(MPI_File fh, const void *buf, int count, MPI_Datatype datatype, MPI_Status *status)
MPI wrapper for MPI_File_write.
Definition wrapper.cpp:264
void Test(MPI_Request *request, int *flag, MPI_Status *status)
MPI wrapper for MPI_Test.
Definition wrapper.cpp:319
void Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
MPI wrapper for MPI_Allgather.
Definition wrapper.cpp:133
void Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Isend.
Definition wrapper.cpp:85
void File_read_at(MPI_File fh, MPI_Offset offset, void *buf, int count, MPI_Datatype datatype, MPI_Status *status)
MPI wrapper for MPI_File_read_at.
Definition wrapper.cpp:292
void File_write_all(MPI_File fh, const void *buf, int count, MPI_Datatype datatype, MPI_Status *status)
MPI wrapper for MPI_File_write_all.
Definition wrapper.cpp:256