32 auto get_hash_comm_map = [](
const std::vector<u64> &vec) {
34 s.resize(vec.size() *
sizeof(
u64));
35 std::memcpy(s.data(), vec.data(), vec.size() *
sizeof(
u64));
36 auto ret = std::hash<std::string>{}(s);
40 auto check_comm_hash = [](
const std::vector<u64> &vec) {
41 auto hash = get_hash_comm_map(vec);
44 auto max_hash = shamalgs::collective::allreduce_max(hash);
45 auto min_hash = shamalgs::collective::allreduce_min(hash);
47 if (max_hash != min_hash) {
48 std::string msg = shambase::format(
49 "hash mismatch {} != {}, local hash = {}", max_hash, min_hash, hash);
50 logger::err_ln(
"Sparse comm", msg);
57 auto check_payload_size_is_int = [](
u64 bytesz,
const std::vector<u64> &global_comm_ranks) {
58 u64 payload_sz = bytesz;
60 if (payload_sz > std::numeric_limits<i32>::max()) {
62 std::vector<u64> send_sizes;
63 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
64 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
67 send_sizes.push_back(payload_sz);
72 "payload size {} is too large for MPI (max i32 is {})\n"
73 "message sizes to send: {}",
75 std::numeric_limits<i32>::max(),
79 return (
i32) payload_sz;
91 auto report_unfinished_requests
93 std::string err_msg =
"";
94 for (
u32 i = 0; i < rqs.size(); i++) {
95 if (!rqs.is_event_ready(i)) {
97 if (rqs_infos[i].is_send) {
98 err_msg += shambase::format(
99 "communication timeout : send {} -> {} tag {} size {}\n",
101 rqs_infos[i].receiver,
105 err_msg += shambase::format(
106 "communication timeout : recv {} -> {} tag {} size {}\n",
108 rqs_infos[i].receiver,
114 std::string msg = shambase::format(
"communication timeout : \n{}", err_msg);
115 logger::err_ln(
"Sparse comm", msg);
116 std::this_thread::sleep_for(std::chrono::seconds(2));
120 auto test_event_completions
121 = [](std::vector<MPI_Request> &rqs, std::vector<rq_info> &rqs_infos) {
127 std::vector<bool> done_map = {};
128 done_map.resize(rqs.size());
129 for (
u32 i = 0; i < rqs.size(); i++) {
133 f64 t_last_print = 0;
138 bool loc_done =
true;
139 for (
u32 i = 0; i < rqs.size(); i++) {
178 = shambase::format(
"Sparse comm : {} / {} done", done_count, rqs.size());
179 logger::warn_ln(
"Sparse comm", msg);
185 std::string err_msg =
"";
186 for (
u32 i = 0; i < rqs.size(); i++) {
189 if (rqs_infos[i].is_send) {
190 err_msg += shambase::format(
191 "communication timeout : send {} -> {} tag {} size {}\n",
193 rqs_infos[i].receiver,
197 err_msg += shambase::format(
198 "communication timeout : recv {} -> {} tag {} size {}\n",
200 rqs_infos[i].receiver,
206 std::string msg = shambase::format(
"communication timeout : \n{}", err_msg);
207 logger::err_ln(
"Sparse comm", msg);
208 std::this_thread::sleep_for(std::chrono::seconds(2));
215auto get_SHAM_SPARSE_COMM_INFLIGHT_LIM = []() {
217 "SHAM_SPARSE_COMM_INFLIGHT_LIM",
"128",
"Maximum number of inflight messages");
221 ret = std::stoull(val);
226 "Invalid value for SHAM_SPARSE_COMM_INFLIGHT_LIM {}, using default value {}",
234const u64 SHAM_SPARSE_COMM_INFLIGHT_LIM = get_SHAM_SPARSE_COMM_INFLIGHT_LIM();
236namespace shamalgs::collective {
237 void sparse_comm_debug_infos(
238 std::shared_ptr<sham::DeviceScheduler> dev_sched,
239 const std::vector<SendPayload> &message_send,
240 std::vector<RecvPayload> &message_recv,
241 const SparseCommTable &comm_table) {
245 const std::vector<u64> &send_vec_comm_ranks = comm_table.local_send_vec_comm_ranks;
246 const std::vector<u64> &global_comm_ranks = comm_table.global_comm_ranks;
249 auto print_comm_mat = [&]() {
253 std::string accum =
"";
256 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
257 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
260 accum += shambase::format(
264 message_send[send_idx].payload->get_size());
273 matrix =
"\n" + matrix;
276 logger::raw_ln(
"comm matrix:", matrix);
284 auto show_alloc_state = [&]() {
288 std::string accum = shambase::format(
289 "rank = {} maxmem = {}\n",
300 logger::raw_ln(
"alloc state:", log);
309 void sparse_comm_isend_probe_count_irecv(
310 std::shared_ptr<sham::DeviceScheduler> dev_sched,
311 const std::vector<SendPayload> &message_send,
312 std::vector<RecvPayload> &message_recv,
313 const SparseCommTable &comm_table) {
317 const std::vector<u64> &send_vec_comm_ranks = comm_table.local_send_vec_comm_ranks;
318 const std::vector<u64> &global_comm_ranks = comm_table.global_comm_ranks;
322 std::vector<MPI_Request> rqs;
326 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
327 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
331 auto &payload = message_send[send_idx].payload;
333 rqs.push_back(MPI_Request{});
334 u32 rq_index = rqs.size() - 1;
335 auto &rq = rqs[rq_index];
337 int send_sz = check_payload_size_is_int(payload->get_size(), global_comm_ranks);
347 payload->get_ptr(), send_sz, MPI_BYTE, comm_ranks.y(), i, MPI_COMM_WORLD, &rq);
354 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
355 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
360 payload.sender_ranks = comm_ranks.x();
362 rqs.push_back(MPI_Request{});
363 u32 rq_index = rqs.size() - 1;
364 auto &rq = rqs[rq_index];
371 payload.payload = std::make_unique<shamcomm::CommunicationBuffer>(cnt, dev_sched);
381 payload.payload->get_ptr(),
389 message_recv.push_back(std::move(payload));
393 std::vector<MPI_Status> st_lst(rqs.size());
397 void sparse_comm_allgather_isend_irecv(
398 std::shared_ptr<sham::DeviceScheduler> dev_sched,
399 const std::vector<SendPayload> &message_send,
400 std::vector<RecvPayload> &message_recv,
401 const SparseCommTable &comm_table) {
405 const std::vector<u64> &send_vec_comm_ranks = comm_table.local_send_vec_comm_ranks;
406 const std::vector<u64> &global_comm_ranks = comm_table.global_comm_ranks;
412 std::vector<int> comm_sizes_loc = {};
413 comm_sizes_loc.resize(message_send.size());
414 for (
u64 i = 0; i < message_send.size(); i++) {
416 = check_payload_size_is_int(message_send[i].payload->get_size(), global_comm_ranks);
420 std::vector<int> comm_sizes = {};
424 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
425 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
427 i32 sender = comm_ranks.x();
428 i32 receiver = comm_ranks.y();
432 payload.sender_ranks = sender;
433 i32 cnt = comm_sizes[i];
435 payload.payload = std::make_unique<shamcomm::CommunicationBuffer>(cnt, dev_sched);
437 message_recv.push_back(std::move(payload));
442 std::vector<rq_info> rqs_infos;
452 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
453 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
455 i32 sender = comm_ranks.x();
456 i32 receiver = comm_ranks.y();
458 i32 tag = tag_map[sender];
461 bool trigger_check =
false;
465 auto &payload = message_send.at(send_idx).payload;
467 auto &rq = rqs.new_request();
469 rqs_infos.push_back({sender, receiver, payload->get_size(), tag,
true,
false});
471 SHAM_ASSERT(payload->get_size() == comm_sizes_loc[send_idx]);
482 comm_sizes_loc[send_idx],
495 auto &payload = message_recv.at(recv_idx).payload;
497 auto &rq = rqs.new_request();
499 rqs_infos.push_back({sender, receiver,
u64(comm_sizes[i]), tag,
false,
true});
509 payload->get_ptr(), comm_sizes[i], MPI_BYTE, sender, tag, MPI_COMM_WORLD, &rq);
516 u64 in_flight_lim = SHAM_SPARSE_COMM_INFLIGHT_LIM;
517 if (in_flight > in_flight_lim) {
522 f64 last_print_time = 0;
529 report_unfinished_requests(rqs, rqs_infos);
532 if (twait.
elasped_sec() - last_print_time > print_freq) {
535 "too many messages in flight :",
541 in_flight = rqs.remain_count();
542 }
while (in_flight > in_flight_lim);
546 test_event_completions(rqs.requests(), rqs_infos);
Provides a helper class to manage a list of MPI requests.
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit signed integer
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
Class Timer measures the time elapsed since the timer was started.
void end()
Stops the timer and stores the elapsed time in nanoseconds.
f64 elasped_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
void start()
Starts the timer.
This header file contains utility functions related to exception handling in the code.
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on vector with size query (size querying variant of vector_allgatherv_ks) //TODO add fault...
void gather_str(const std::string &send_vec, std::string &recv_vec)
Gathers a string from all nodes and stores the result in a std::string.
MemPerfInfos get_mem_perf_info()
Retrieve the memory performance information.
std::string readable_sizeof(double size)
Given a sizeof value, return a readable string. Example: readable_sizeof(1024*1024*1024) -> "1....
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
std::string getenv_str_default_register(const char *env_var, std::string default_val, std::string desc)
Get the content of the environment variable if it exists and register its documentation,...
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
i32 world_size()
Gives the size of the MPI communicator.
Structure to store the performance information about memory allocation and deallocation.
size_t max_allocated_byte_device
max bytes allocated on the device
Functions related to the MPI communicator.
void Get_count(const MPI_Status *status, MPI_Datatype datatype, int *count)
MPI wrapper for MPI_Get_count.
void Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Irecv.
void Probe(int source, int tag, MPI_Comm comm, MPI_Status *status)
MPI wrapper for MPI_Probe.
void Barrier(MPI_Comm comm)
MPI wrapper for MPI_Barrier.
void Waitall(int count, MPI_Request array_of_requests[], MPI_Status *array_of_statuses)
MPI wrapper for MPI_Waitall.
void Test(MPI_Request *request, int *flag, MPI_Status *status)
MPI wrapper for MPI_Test.
void Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Isend.