23#include <unordered_map>
28 std::unordered_map<std::string, f64> mpi_timers;
32namespace shamcomm::mpi {
34 mpi_timers[timername] += time;
35 mpi_timers[
"total"] += time;
37 if (shambase::profiling::is_profiling_enabled()) {
38 auto wtime = shambase::details::get_wtime();
39 shambase::profiling::register_counter_val(timername, wtime, mpi_timers[timername]);
40 shambase::profiling::register_counter_val(
"total MPi time", wtime, mpi_timers[
"total"]);
44 f64 get_timer(std::string timername) {
return mpi_timers[timername]; }
46 const std::unordered_map<std::string, f64> &
get_timers() {
return mpi_timers; }
48 std::vector<std::string> possible_keys{
49 "total",
"MPI_Isend",
"MPI_Irecv",
50 "MPI_Allreduce",
"MPI_Allgather",
"MPI_Allgatherv",
51 "MPI_Exscan",
"MPI_Wait",
"MPI_Waitall",
52 "MPI_Barrier",
"MPI_Probe",
"MPI_Recv",
53 "MPI_Get_count",
"MPI_Send",
"MPI_File_set_view",
54 "MPI_Type_size",
"MPI_File_write_all",
"MPI_File_write",
55 "MPI_File_read",
"MPI_File_write_at",
"MPI_File_read_at",
56 "MPI_File_close",
"MPI_File_open",
"MPI_Test",
57 "MPI_Gather",
"MPI_Gatherv",
67 inline void wrap_profiling(std::string timername, Func &&f) {
69 tstart = shambase::details::get_wtime();
71 shamcomm::mpi::register_time(timername, shambase::details::get_wtime() - tstart);
76namespace shamcomm::mpi {
78 void check_tag_value(
i32 tag) {
88 MPI_Datatype datatype,
92 MPI_Request *request) {
97 wrap_profiling(
"MPI_Isend", [&]() {
98 MPICHECK(MPI_Isend(buf, count, datatype, dest, tag, comm, request));
105 MPI_Datatype datatype,
109 MPI_Request *request) {
112 check_tag_value(tag);
114 wrap_profiling(
"MPI_Irecv", [&]() {
115 MPICHECK(MPI_Irecv(buf, count, datatype, source, tag, comm, request));
123 MPI_Datatype datatype,
128 wrap_profiling(
"MPI_Allreduce", [&]() {
129 MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, datatype, op, comm));
136 MPI_Datatype sendtype,
139 MPI_Datatype recvtype,
143 wrap_profiling(
"MPI_Allgather", [&]() {
145 MPI_Allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm));
152 MPI_Datatype sendtype,
154 const int recvcounts[],
156 MPI_Datatype recvtype,
160 wrap_profiling(
"MPI_Allgatherv", [&]() {
162 sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm));
170 MPI_Datatype datatype,
175 wrap_profiling(
"MPI_Exscan", [&]() {
176 MPICHECK(MPI_Exscan(sendbuf, recvbuf, count, datatype, op, comm));
180 void Wait(MPI_Request *request, MPI_Status *status) {
182 wrap_profiling(
"MPI_Wait", [&]() {
183 MPICHECK(MPI_Wait(request, status));
187 void Waitall(
int count, MPI_Request array_of_requests[], MPI_Status *array_of_statuses) {
189 wrap_profiling(
"MPI_Waitall", [&]() {
190 MPICHECK(MPI_Waitall(count, array_of_requests, array_of_statuses));
196 wrap_profiling(
"MPI_Barrier", [&]() {
201 void Probe(
int source,
int tag, MPI_Comm comm, MPI_Status *status) {
203 wrap_profiling(
"MPI_Probe", [&]() {
204 MPICHECK(MPI_Probe(source, tag, comm, status));
211 MPI_Datatype datatype,
215 MPI_Status *status) {
217 wrap_profiling(
"MPI_Recv", [&]() {
218 MPICHECK(MPI_Recv(buf, count, datatype, source, tag, comm, status));
222 void Get_count(
const MPI_Status *status, MPI_Datatype datatype,
int *count) {
224 wrap_profiling(
"MPI_Get_count", [&]() {
225 MPICHECK(MPI_Get_count(status, datatype, count));
229 void Send(
const void *buf,
int count, MPI_Datatype datatype,
int dest,
int tag, MPI_Comm comm) {
231 wrap_profiling(
"MPI_Send", [&]() {
232 MPICHECK(MPI_Send(buf, count, datatype, dest, tag, comm));
240 MPI_Datatype filetype,
244 wrap_profiling(
"MPI_File_set_view", [&]() {
245 MPICHECK(MPI_File_set_view(fh, disp, etype, filetype, datarep, info));
251 wrap_profiling(
"MPI_Type_size", [&]() {
252 MPICHECK(MPI_Type_size(type, size));
257 MPI_File fh,
const void *buf,
int count, MPI_Datatype datatype, MPI_Status *status) {
259 wrap_profiling(
"MPI_File_write_all", [&]() {
260 MPICHECK(MPI_File_write_all(fh, buf, count, datatype, status));
265 MPI_File fh,
const void *buf,
int count, MPI_Datatype datatype, MPI_Status *status) {
267 wrap_profiling(
"MPI_File_write", [&]() {
268 MPICHECK(MPI_File_write(fh, buf, count, datatype, status));
272 void File_read(MPI_File fh,
void *buf,
int count, MPI_Datatype datatype, MPI_Status *status) {
274 wrap_profiling(
"MPI_File_read", [&]() {
275 MPICHECK(MPI_File_read(fh, buf, count, datatype, status));
284 MPI_Datatype datatype,
285 MPI_Status *status) {
287 wrap_profiling(
"MPI_File_write_at", [&]() {
288 MPICHECK(MPI_File_write_at(fh, offset, buf, count, datatype, status));
297 MPI_Datatype datatype,
298 MPI_Status *status) {
300 wrap_profiling(
"MPI_File_read_at", [&]() {
301 MPICHECK(MPI_File_read_at(fh, offset, buf, count, datatype, status));
307 wrap_profiling(
"MPI_File_close", [&]() {
312 void File_open(MPI_Comm comm,
const char *filename,
int amode, MPI_Info info, MPI_File *fh) {
314 wrap_profiling(
"MPI_File_open", [&]() {
315 MPICHECK(MPI_File_open(comm, filename, amode, info, fh));
319 void Test(MPI_Request *request,
int *flag, MPI_Status *status) {
321 wrap_profiling(
"MPI_Test", [&]() {
322 MPICHECK(MPI_Test(request, flag, status));
329 MPI_Datatype sendtype,
332 MPI_Datatype recvtype,
336 wrap_profiling(
"MPI_Gather", [&]() {
338 MPI_Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm));
345 MPI_Datatype sendtype,
347 const int recvcounts[],
349 MPI_Datatype recvtype,
353 wrap_profiling(
"MPI_Gatherv", [&]() {
355 sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm));
double f64
Alias for double.
std::int32_t i32
32 bit integer
Utility functions for MPI error checking.
#define MPICHECK(mpicall)
Shortcut macro to check MPI return codes.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
i32 mpi_max_tag_value()
Gets the maximum value of the MPI tag.
This file contains the definition for the stacktrace related functionality.
Functions related to the MPI communicator.
void File_set_view(MPI_File fh, MPI_Offset disp, MPI_Datatype etype, MPI_Datatype filetype, const char *datarep, MPI_Info info)
MPI wrapper for MPI_File_set_view.
void Get_count(const MPI_Status *status, MPI_Datatype datatype, int *count)
MPI wrapper for MPI_Get_count.
void Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Status *status)
MPI wrapper for MPI_Recv.
void Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Irecv.
void Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
MPI wrapper for MPI_Exscan.
void Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Datatype recvtype, int root, MPI_Comm comm)
MPI wrapper for MPI_Gatherv.
void Probe(int source, int tag, MPI_Comm comm, MPI_Status *status)
MPI wrapper for MPI_Probe.
void Barrier(MPI_Comm comm)
MPI wrapper for MPI_Barrier.
void File_close(MPI_File *fh)
MPI wrapper for MPI_File_close.
void File_read(MPI_File fh, void *buf, int count, MPI_Datatype datatype, MPI_Status *status)
MPI wrapper for MPI_File_read.
void Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[], const int displs[], MPI_Datatype recvtype, MPI_Comm comm)
MPI wrapper for MPI_Allgatherv.
const std::vector< std::string > & get_possible_keys()
Return all possible keys for the internal timers.
void register_time(std::string timername, f64 time)
Register a timer value.
f64 get_timer(std::string timername)
Get a timer value.
void File_write_at(MPI_File fh, MPI_Offset offset, const void *buf, int count, MPI_Datatype datatype, MPI_Status *status)
MPI wrapper for MPI_File_write_at.
void File_open(MPI_Comm comm, const char *filename, int amode, MPI_Info info, MPI_File *fh)
MPI wrapper for MPI_File_open.
void Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
MPI wrapper for MPI_Allreduce.
void Waitall(int count, MPI_Request array_of_requests[], MPI_Status *array_of_statuses)
MPI wrapper for MPI_Waitall.
void Wait(MPI_Request *request, MPI_Status *status)
MPI wrapper for MPI_Wait.
void Type_size(MPI_Datatype type, int *size)
MPI wrapper for MPI_Type_size.
const std::unordered_map< std::string, f64 > & get_timers()
Return all internal timers.
void Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
MPI wrapper for MPI_Gather.
void File_write(MPI_File fh, const void *buf, int count, MPI_Datatype datatype, MPI_Status *status)
MPI wrapper for MPI_File_write.
void Test(MPI_Request *request, int *flag, MPI_Status *status)
MPI wrapper for MPI_Test.
void Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
MPI wrapper for MPI_Allgather.
void Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Isend.
void File_read_at(MPI_File fh, MPI_Offset offset, void *buf, int count, MPI_Datatype datatype, MPI_Status *status)
MPI wrapper for MPI_File_read_at.
void File_write_all(MPI_File fh, const void *buf, int count, MPI_Datatype datatype, MPI_Status *status)
MPI wrapper for MPI_File_write_all.