Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
sparseXchg.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
19#include "shambase/string.hpp"
20#include "shambase/time.hpp"
22#include "shamcmdopt/env.hpp"
23#include "shamcomm/logs.hpp"
25#include <stdexcept>
26#include <string>
27#include <thread>
28#include <vector>
29
30namespace {
31
32 auto get_hash_comm_map = [](const std::vector<u64> &vec) {
33 std::string s = "";
34 s.resize(vec.size() * sizeof(u64));
35 std::memcpy(s.data(), vec.data(), vec.size() * sizeof(u64));
36 auto ret = std::hash<std::string>{}(s);
37 return ret;
38 };
39
    // Cross-rank consistency check: each rank hashes its view of the comm map
    // and the min/max reductions of that hash over MPI_COMM_WORLD must match.
    auto check_comm_hash = [](const std::vector<u64> &vec) {
        auto hash = get_hash_comm_map(vec);
        // logger::raw_ln("global_comm_ranks hash", hash);

        auto max_hash = shamalgs::collective::allreduce_max(hash);
        auto min_hash = shamalgs::collective::allreduce_min(hash);

        // min != max means at least one rank has a diverging comm map
        if (max_hash != min_hash) {
            std::string msg = shambase::format(
                "hash mismatch {} != {}, local hash = {}", max_hash, min_hash, hash);
            logger::err_ln("Sparse comm", msg);
            shamcomm::mpi::Barrier(MPI_COMM_WORLD);
            // NOTE(review): one source line is missing from this listing here
            // (presumably an exception throw after the barrier) — confirm
            // against the original file.
        }
    };
55
    // Checks that a payload byte size fits in an MPI count (i32) and returns it
    // cast to i32; on overflow, reports the sizes this rank was about to send.
    auto check_payload_size_is_int = [](u64 bytesz, const std::vector<u64> &global_comm_ranks) {
        u64 payload_sz = bytesz;

        if (payload_sz > std::numeric_limits<i32>::max()) {

            // Collect this rank's outgoing message sizes for the error report.
            // NOTE(review): this pushes the same payload_sz once per outgoing
            // message — presumably the per-message sizes were intended; confirm.
            std::vector<u64> send_sizes;
            for (u32 i = 0; i < global_comm_ranks.size(); i++) {
                // each entry packs (sender, receiver) world ranks into one u64
                u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);

                if (comm_ranks.x() == shamcomm::world_rank()) {
                    send_sizes.push_back(payload_sz);
                }
            }

            // NOTE(review): the line opening this call (presumably
            // shambase::throw_with_loc(shambase::format() is missing from this
            // listing — confirm against the original file.
                "payload size {} is too large for MPI (max i32 is {})\n"
                "message sizes to send: {}",
                payload_sz,
                std::numeric_limits<i32>::max(),
                send_sizes));
        }

        return (i32) payload_sz;
    };
81
    /// Bookkeeping attached to each MPI request, used only for human-readable
    /// timeout / error reporting. Initialized positionally (see the
    /// rqs_infos.push_back calls) — do not reorder fields.
    struct rq_info {
        i32 sender;   ///< world rank of the sending side
        i32 receiver; ///< world rank of the receiving side
        u64 size;     ///< message size in bytes
        i32 tag;      ///< MPI tag used for the message
        bool is_send; ///< true when the request is an Isend
        bool is_recv; ///< true when the request is an Irecv
    };
90
    // Builds and logs a report of every request in `rqs` that has not yet
    // completed, then sleeps briefly so all ranks get a chance to flush logs.
    auto report_unfinished_requests
        = [](shamalgs::collective::RequestList &rqs, std::vector<rq_info> &rqs_infos) {
              std::string err_msg = "";
              for (u32 i = 0; i < rqs.size(); i++) {
                  if (!rqs.is_event_ready(i)) {

                      // describe the pending request using the side-channel infos
                      if (rqs_infos[i].is_send) {
                          err_msg += shambase::format(
                              "communication timeout : send {} -> {} tag {} size {}\n",
                              rqs_infos[i].sender,
                              rqs_infos[i].receiver,
                              rqs_infos[i].tag,
                              rqs_infos[i].size);
                      } else {
                          err_msg += shambase::format(
                              "communication timeout : recv {} -> {} tag {} size {}\n",
                              rqs_infos[i].sender,
                              rqs_infos[i].receiver,
                              rqs_infos[i].tag,
                              rqs_infos[i].size);
                      }
                  }
              }
              std::string msg = shambase::format("communication timeout : \n{}", err_msg);
              logger::err_ln("Sparse comm", msg);
              // give other ranks time to emit their own reports before whatever
              // follows (NOTE(review): one source line is missing from this
              // listing after the sleep — likely a throw/abort; confirm)
              std::this_thread::sleep_for(std::chrono::seconds(2));
          };
119
    // Polls every MPI request with MPI_Test until all have completed,
    // logging progress roughly every 10 s and emitting a detailed
    // per-request report once the 120 s timeout is exceeded.
    auto test_event_completions
        = [](std::vector<MPI_Request> &rqs, std::vector<rq_info> &rqs_infos) {
              shambase::Timer twait;
              twait.start();
              f64 timeout_t = 120; // seconds before the timeout report triggers
              // NOTE(review): freq_print is never read — the progress print
              // below uses the literal 10 instead; confirm intended behavior
              f64 freq_print = 10;

              // done_map[i] == true once request i has completed
              std::vector<bool> done_map = {};
              done_map.resize(rqs.size());
              for (u32 i = 0; i < rqs.size(); i++) {
                  done_map[i] = false;
              }

              f64 t_last_print = 0;
              u64 done_count = 0;

              bool done = false;
              while (!done) {
                  bool loc_done = true;
                  for (u32 i = 0; i < rqs.size(); i++) {
                      if (done_map[i]) {
                          continue; // already completed, skip the MPI_Test
                      }

                      auto &rq = rqs[i];

                      MPI_Status st; // NOTE(review): unused — Test gets MPI_STATUS_IGNORE
                      int ready;
                      shamcomm::mpi::Test(&rq, &ready, MPI_STATUS_IGNORE);
                      if (!ready) {
                          loc_done = false;
                          // logger::raw_ln(shambase::format(
                          //     "communication pending : send {} -> {} tag {} size {}",
                          //     rqs_infos[i].sender,
                          //     rqs_infos[i].receiver,
                          //     rqs_infos[i].tag,
                          //     rqs_infos[i].size));
                      } else {
                          done_map[i] = true;
                          done_count++;
                          // logger::raw_ln(shambase::format(
                          //     "communication done : send {} -> {} tag {} size {}",
                          //     rqs_infos[i].sender,
                          //     rqs_infos[i].receiver,
                          //     rqs_infos[i].tag,
                          //     rqs_infos[i].size));
                      }
                  }

                  if (loc_done) {
                      done = true;
                  }

                  twait.end();

                  // periodic progress report (every 10 s since the last print)
                  if (twait.elasped_sec() > t_last_print + 10) {

                      std::string msg
                          = shambase::format("Sparse comm : {} / {} done", done_count, rqs.size());
                      logger::warn_ln("Sparse comm", msg);

                      t_last_print = twait.elasped_sec();
                  }

                  // timeout: dump every still-pending request with its metadata
                  if (twait.elasped_sec() > timeout_t) {
                      std::string err_msg = "";
                      for (u32 i = 0; i < rqs.size(); i++) {
                          if (!done_map[i]) {

                              if (rqs_infos[i].is_send) {
                                  err_msg += shambase::format(
                                      "communication timeout : send {} -> {} tag {} size {}\n",
                                      rqs_infos[i].sender,
                                      rqs_infos[i].receiver,
                                      rqs_infos[i].tag,
                                      rqs_infos[i].size);
                              } else {
                                  err_msg += shambase::format(
                                      "communication timeout : recv {} -> {} tag {} size {}\n",
                                      rqs_infos[i].sender,
                                      rqs_infos[i].receiver,
                                      rqs_infos[i].tag,
                                      rqs_infos[i].size);
                              }
                          }
                      }
                      std::string msg = shambase::format("communication timeout : \n{}", err_msg);
                      logger::err_ln("Sparse comm", msg);
                      std::this_thread::sleep_for(std::chrono::seconds(2));
                      // NOTE(review): one source line is missing from this
                      // listing after the sleep (likely a throw/abort) —
                      // confirm against the original file.
                  }
              }
          };
213} // namespace
214
/// Reads the SHAM_SPARSE_COMM_INFLIGHT_LIM environment variable (default 128);
/// falls back to the default on any parse failure, with an error log.
auto get_SHAM_SPARSE_COMM_INFLIGHT_LIM = []() {
    // NOTE(review): the line declaring `val` is missing from this listing —
    // presumably `std::string val = shamcmdopt::getenv_str_default_register(`;
    // confirm against the original file.
        "SHAM_SPARSE_COMM_INFLIGHT_LIM", "128", "Maximum number of inflight messages");

    u64 ret = 128; // default limit, also the fallback on parse error
    try {
        ret = std::stoull(val);
    } catch (...) {
        // keep the default but tell the user their setting was ignored
        logger::err_ln(
            "Sparse comm",
            shambase::format(
                "Invalid value for SHAM_SPARSE_COMM_INFLIGHT_LIM {}, using default value {}",
                val,
                ret));
    }

    return ret;
};

/// Maximum number of in-flight sparse-comm messages, resolved once at startup.
const u64 SHAM_SPARSE_COMM_INFLIGHT_LIM = get_SHAM_SPARSE_COMM_INFLIGHT_LIM();
235
namespace shamalgs::collective {
    /// Debug helper: prints the communication matrix (sender # receiver # size)
    /// and the allocator state, gathered on rank 0. Not part of the exchange
    /// itself — intended to be enabled only while debugging.
    void sparse_comm_debug_infos(
        std::shared_ptr<sham::DeviceScheduler> dev_sched,
        const std::vector<SendPayload> &message_send,
        std::vector<RecvPayload> &message_recv,
        const SparseCommTable &comm_table) {
        StackEntry stack_loc{};

        // share comm list across nodes
        const std::vector<u64> &send_vec_comm_ranks = comm_table.local_send_vec_comm_ranks;
        const std::vector<u64> &global_comm_ranks = comm_table.global_comm_ranks;

        // Utility lambda for printing comm matrix
        auto print_comm_mat = [&]() {
            StackEntry stack_loc{};

            shamcomm::mpi::Barrier(MPI_COMM_WORLD);
            std::string accum = "";

            // one line per message this rank sends: "sender # receiver # bytes"
            u32 send_idx = 0;
            for (u32 i = 0; i < global_comm_ranks.size(); i++) {
                u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);

                if (comm_ranks.x() == shamcomm::world_rank()) {
                    accum += shambase::format(
                        "{} # {} # {}\n",
                        comm_ranks.x(),
                        comm_ranks.y(),
                        message_send[send_idx].payload->get_size());

                    send_idx++;
                }
            }

            std::string matrix;
            // NOTE(review): a line is missing from this listing here —
            // presumably a gather of `accum` into `matrix` (gather_str);
            // confirm against the original file.

            matrix = "\n" + matrix;

            if (shamcomm::world_rank() == 0) {
                logger::raw_ln("comm matrix:", matrix);
            }
            shamcomm::mpi::Barrier(MPI_COMM_WORLD);
        };

        // Enable this only to do debug
        print_comm_mat();

        // Gathers each rank's allocator stats and prints them on rank 0
        auto show_alloc_state = [&]() {
            StackEntry stack_loc{};

            // NOTE(review): the argument lines of this format call are missing
            // from this listing — confirm against the original file.
            std::string accum = shambase::format(
                "rank = {} maxmem = {}\n",

            shamcomm::mpi::Barrier(MPI_COMM_WORLD);
            std::string log;
            // NOTE(review): a line is missing here as well (presumably a
            // gather of `accum` into `log`) — confirm.

            log = "\n" + log;

            if (shamcomm::world_rank() == 0) {
                logger::raw_ln("alloc state:", log);
            }
            shamcomm::mpi::Barrier(MPI_COMM_WORLD);
        };

        // Enable this only to do debug
        show_alloc_state();
    }
308
    /// Sparse exchange variant: each sender posts Isend, each receiver uses
    /// Probe + Get_count to size its buffer before posting Irecv; a final
    /// Waitall completes everything. The global_comm_ranks index `i` is used
    /// as the MPI tag on both sides.
    void sparse_comm_isend_probe_count_irecv(
        std::shared_ptr<sham::DeviceScheduler> dev_sched,
        const std::vector<SendPayload> &message_send,
        std::vector<RecvPayload> &message_recv,
        const SparseCommTable &comm_table) {
        StackEntry stack_loc{};

        // share comm list across nodes
        const std::vector<u64> &send_vec_comm_ranks = comm_table.local_send_vec_comm_ranks;
        const std::vector<u64> &global_comm_ranks = comm_table.global_comm_ranks;

        // note the tag cannot be bigger than max_i32 because of the allgatherv

        std::vector<MPI_Request> rqs;

        // send step
        u32 send_idx = 0;
        for (u32 i = 0; i < global_comm_ranks.size(); i++) {
            // each entry packs (sender, receiver) world ranks into one u64
            u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);

            if (comm_ranks.x() == shamcomm::world_rank()) {

                auto &payload = message_send[send_idx].payload;

                rqs.push_back(MPI_Request{});
                u32 rq_index = rqs.size() - 1;
                auto &rq = rqs[rq_index];

                // aborts/reports if the size does not fit an MPI i32 count
                int send_sz = check_payload_size_is_int(payload->get_size(), global_comm_ranks);

                // logger::raw_ln(shambase::format(
                //     "[{}] send {} bytes to rank {}, tag {}",
                //     shamcomm::world_rank(),
                //     payload->get_bytesize(),
                //     comm_ranks.y(),
                //     i));

                // NOTE(review): the call opener (presumably
                // shamcomm::mpi::Isend() is missing from this listing —
                // confirm against the original file.
                    payload->get_ptr(), send_sz, MPI_BYTE, comm_ranks.y(), i, MPI_COMM_WORLD, &rq);

                send_idx++;
            }
        }

        // recv step
        for (u32 i = 0; i < global_comm_ranks.size(); i++) {
            u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);

            if (comm_ranks.y() == shamcomm::world_rank()) {

                RecvPayload payload;
                payload.sender_ranks = comm_ranks.x();

                rqs.push_back(MPI_Request{});
                u32 rq_index = rqs.size() - 1;
                auto &rq = rqs[rq_index];

                // blocking Probe to learn the incoming message size
                MPI_Status st;
                i32 cnt;
                shamcomm::mpi::Probe(comm_ranks.x(), i, MPI_COMM_WORLD, &st);
                shamcomm::mpi::Get_count(&st, MPI_BYTE, &cnt);

                payload.payload = std::make_unique<shamcomm::CommunicationBuffer>(cnt, dev_sched);

                // logger::raw_ln(shambase::format(
                //     "[{}] recv {} bytes from rank {}, tag {}",
                //     shamcomm::world_rank(),
                //     cnt,
                //     comm_ranks.x(),
                //     i));

                // NOTE(review): the call opener (presumably
                // shamcomm::mpi::Irecv() is missing from this listing —
                // confirm against the original file.
                    payload.payload->get_ptr(),
                    cnt,
                    MPI_BYTE,
                    comm_ranks.x(),
                    i,
                    MPI_COMM_WORLD,
                    &rq);

                message_recv.push_back(std::move(payload));
            }
        }

        // wait for every posted send and recv to complete
        std::vector<MPI_Status> st_lst(rqs.size());
        shamcomm::mpi::Waitall(rqs.size(), rqs.data(), st_lst.data());
    }
396
    /// Sparse exchange variant: message sizes are allgathered first so
    /// receivers can pre-allocate buffers, then Isend/Irecv pairs are posted
    /// with per-sender incrementing tags. The number of in-flight requests is
    /// capped by SHAM_SPARSE_COMM_INFLIGHT_LIM, with progress/timeout logging.
    void sparse_comm_allgather_isend_irecv(
        std::shared_ptr<sham::DeviceScheduler> dev_sched,
        const std::vector<SendPayload> &message_send,
        std::vector<RecvPayload> &message_recv,
        const SparseCommTable &comm_table) {
        StackEntry stack_loc{};

        // share comm list across nodes
        const std::vector<u64> &send_vec_comm_ranks = comm_table.local_send_vec_comm_ranks;
        const std::vector<u64> &global_comm_ranks = comm_table.global_comm_ranks;

        // check hash
        // check_comm_hash(global_comm_ranks);

        // Build global comm size table
        std::vector<int> comm_sizes_loc = {};
        comm_sizes_loc.resize(message_send.size());
        for (u64 i = 0; i < message_send.size(); i++) {
            // also validates each size fits in an MPI i32 count
            comm_sizes_loc[i]
                = check_payload_size_is_int(message_send[i].payload->get_size(), global_comm_ranks);
        }

        // gather sizes
        std::vector<int> comm_sizes = {};
        vector_allgatherv(comm_sizes_loc, comm_sizes, MPI_COMM_WORLD);

        // Init the receiving buffers
        for (u32 i = 0; i < global_comm_ranks.size(); i++) {
            // each entry packs (sender, receiver) world ranks into one u64
            u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);

            i32 sender = comm_ranks.x();
            i32 receiver = comm_ranks.y();

            if (receiver == shamcomm::world_rank()) {
                RecvPayload payload;
                payload.sender_ranks = sender;
                i32 cnt = comm_sizes[i];

                payload.payload = std::make_unique<shamcomm::CommunicationBuffer>(cnt, dev_sched);

                message_recv.push_back(std::move(payload));
            }
        }

        RequestList rqs;
        std::vector<rq_info> rqs_infos; // parallel metadata for timeout reports

        // per-sender tag counter: message k from a given sender carries tag k,
        // which disambiguates multiple messages between the same rank pair
        std::vector<i32> tag_map(shamcomm::world_size(), 0);

        // send step
        u32 send_idx = 0;
        u32 recv_idx = 0;

        u32 in_flight = 0;

        for (u32 i = 0; i < global_comm_ranks.size(); i++) {
            u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);

            i32 sender = comm_ranks.x();
            i32 receiver = comm_ranks.y();

            i32 tag = tag_map[sender];
            tag_map[sender]++;

            // NOTE(review): trigger_check is never read in this listing
            bool trigger_check = false;

            if (sender == shamcomm::world_rank()) {

                auto &payload = message_send.at(send_idx).payload;

                auto &rq = rqs.new_request();

                rqs_infos.push_back({sender, receiver, payload->get_size(), tag, true, false});

                SHAM_ASSERT(payload->get_size() == comm_sizes_loc[send_idx]);

                // logger::raw_ln(shambase::format(
                //     "[{}] send {} bytes to rank {}, tag {}",
                //     shamcomm::world_rank(),
                //     payload->get_bytesize(),
                //     comm_ranks.y(),
                //     i));

                // NOTE(review): the call opener (presumably
                // shamcomm::mpi::Isend() is missing from this listing —
                // confirm against the original file.
                    payload->get_ptr(),
                    comm_sizes_loc[send_idx],
                    MPI_BYTE,
                    receiver,
                    tag,
                    MPI_COMM_WORLD,
                    &rq);

                send_idx++;
                in_flight++;
            }

            if (receiver == shamcomm::world_rank()) {

                auto &payload = message_recv.at(recv_idx).payload;

                auto &rq = rqs.new_request();

                rqs_infos.push_back({sender, receiver, u64(comm_sizes[i]), tag, false, true});

                // logger::raw_ln(shambase::format(
                //     "[{}] recv {} bytes from rank {}, tag {}",
                //     shamcomm::world_rank(),
                //     cnt,
                //     comm_ranks.x(),
                //     i));

                // NOTE(review): the call opener (presumably
                // shamcomm::mpi::Irecv() is missing from this listing —
                // confirm against the original file.
                    payload->get_ptr(), comm_sizes[i], MPI_BYTE, sender, tag, MPI_COMM_WORLD, &rq);

                recv_idx++;
                in_flight++;
            }

            // routine to limit the number of in-flight messages
            u64 in_flight_lim = SHAM_SPARSE_COMM_INFLIGHT_LIM;
            if (in_flight > in_flight_lim) {

                f64 timeout = 120;    // seconds
                f64 print_freq = 10;  // seconds

                f64 last_print_time = 0;

                // spin until enough requests drain, warning periodically and
                // dumping a full report past the timeout
                shambase::Timer twait;
                twait.start();
                do {
                    twait.end();
                    if (twait.elasped_sec() > timeout) {
                        report_unfinished_requests(rqs, rqs_infos);
                    }

                    if (twait.elasped_sec() - last_print_time > print_freq) {
                        logger::warn_ln(
                            "SparseComm",
                            "too many messages in flight :",
                            in_flight,
                            "/",
                            in_flight_lim);
                        last_print_time = twait.elasped_sec();
                    }
                    in_flight = rqs.remain_count();
                } while (in_flight > in_flight_lim);
            }
        }

        // poll with progress/timeout logging, then do the final blocking wait
        test_event_completions(rqs.requests(), rqs_infos);

        rqs.wait_all();

        // logger::raw_ln(tag_map);

        // shamcomm::mpi::Barrier(MPI_COMM_WORLD);
        // if (shamcomm::world_rank() == 0) {
        //     logger::raw_ln(shambase::format("sparse comm done"));
        // }
        // shamcomm::mpi::Barrier(MPI_COMM_WORLD);
    }
558} // namespace shamalgs::collective
Provides a helper class to manage a list of MPI requests.
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
Definition assert.hpp:67
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:96
void end()
Stops the timer and stores the elapsed time in nanoseconds.
Definition time.hpp:111
f64 elasped_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
Definition time.hpp:123
void start()
Starts the timer.
Definition time.hpp:106
This header file contains utility functions related to exception handling in the code.
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on vector with size query (size querying variant of vector_allgatherv_ks) //TODO add fault...
Definition exchanges.hpp:98
void gather_str(const std::string &send_vec, std::string &recv_vec)
Gathers a string from all nodes and store the result in a std::string.
MemPerfInfos get_mem_perf_info()
Retrieve the memory performance information.
std::string readable_sizeof(double size)
given a sizeof value return a readable string Example : readable_sizeof(1024*1024*1024) -> "1....
Definition string.hpp:139
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
std::string getenv_str_default_register(const char *env_var, std::string default_val, std::string desc)
Get the content of the environment variable if it exist and register it documentation,...
Definition env.hpp:88
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
i32 world_size()
Gives the size of the MPI communicator.
Definition worldInfo.cpp:38
Structure to store the performance informations about memory allocation and deallocation.
size_t max_allocated_byte_device
max bytes allocated on the device
Functions related to the MPI communicator.
void Get_count(const MPI_Status *status, MPI_Datatype datatype, int *count)
MPI wrapper for MPI_Get_count.
Definition wrapper.cpp:222
void Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Irecv.
Definition wrapper.cpp:102
void Probe(int source, int tag, MPI_Comm comm, MPI_Status *status)
MPI wrapper for MPI_Probe.
Definition wrapper.cpp:201
void Barrier(MPI_Comm comm)
MPI wrapper for MPI_Barrier.
Definition wrapper.cpp:194
void Waitall(int count, MPI_Request array_of_requests[], MPI_Status *array_of_statuses)
MPI wrapper for MPI_Waitall.
Definition wrapper.cpp:187
void Test(MPI_Request *request, int *flag, MPI_Status *status)
MPI wrapper for MPI_Test.
Definition wrapper.cpp:319
void Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Isend.
Definition wrapper.cpp:85