32 auto get_hash_comm_map = [](
const std::vector<u64> &vec) {
34 s.resize(vec.size() *
sizeof(
u64));
35 std::memcpy(s.data(), vec.data(), vec.size() *
sizeof(
u64));
36 auto ret = std::hash<std::string>{}(s);
40 auto check_comm_hash = [](
const std::vector<u64> &vec) {
41 auto hash = get_hash_comm_map(vec);
44 auto max_hash = shamalgs::collective::allreduce_max(hash);
45 auto min_hash = shamalgs::collective::allreduce_min(hash);
47 if (max_hash != min_hash) {
48 std::string msg = shambase::format(
49 "hash mismatch {} != {}, local hash = {}", max_hash, min_hash, hash);
57 auto check_payload_size_is_int = [](
u64 bytesz,
const std::vector<u64> &global_comm_ranks) {
58 u64 payload_sz = bytesz;
60 if (payload_sz > std::numeric_limits<i32>::max()) {
62 std::vector<u64> send_sizes;
63 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
64 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
67 send_sizes.push_back(payload_sz);
72 "payload size {} is too large for MPI (max i32 is {})\n"
73 "message sizes to send: {}",
75 std::numeric_limits<i32>::max(),
79 return (
i32) payload_sz;
91 auto report_unfinished_requests
93 std::string err_msg =
"";
94 for (
u32 i = 0; i < rqs.size(); i++) {
95 if (!rqs.is_event_ready(i)) {
97 if (rqs_infos[i].is_send) {
98 err_msg += shambase::format(
99 "communication timeout : send {} -> {} tag {} size {}\n",
101 rqs_infos[i].receiver,
105 err_msg += shambase::format(
106 "communication timeout : recv {} -> {} tag {} size {}\n",
108 rqs_infos[i].receiver,
114 std::string msg = shambase::format(
"communication timeout : \n{}", err_msg);
116 std::this_thread::sleep_for(std::chrono::seconds(2));
120 auto test_event_completions
121 = [](std::vector<MPI_Request> &rqs, std::vector<rq_info> &rqs_infos) {
127 std::vector<bool> done_map = {};
128 done_map.resize(rqs.size());
129 for (
u32 i = 0; i < rqs.size(); i++) {
133 f64 t_last_print = 0;
138 bool loc_done =
true;
139 for (
u32 i = 0; i < rqs.size(); i++) {
178 = shambase::format(
"Sparse comm : {} / {} done", done_count, rqs.size());
185 std::string err_msg =
"";
186 for (
u32 i = 0; i < rqs.size(); i++) {
189 if (rqs_infos[i].is_send) {
190 err_msg += shambase::format(
191 "communication timeout : send {} -> {} tag {} size {}\n",
193 rqs_infos[i].receiver,
197 err_msg += shambase::format(
198 "communication timeout : recv {} -> {} tag {} size {}\n",
200 rqs_infos[i].receiver,
206 std::string msg = shambase::format(
"communication timeout : \n{}", err_msg);
208 std::this_thread::sleep_for(std::chrono::seconds(2));
215auto get_SHAM_SPARSE_COMM_INFLIGHT_LIM = []() {
217 "SHAM_SPARSE_COMM_INFLIGHT_LIM",
"128",
"Maximum number of inflight messages");
221 ret = std::stoull(val);
226 "Invalid value for SHAM_SPARSE_COMM_INFLIGHT_LIM {}, using default value {}",
234const u64 SHAM_SPARSE_COMM_INFLIGHT_LIM = get_SHAM_SPARSE_COMM_INFLIGHT_LIM();
236namespace shamalgs::collective {
237 void sparse_comm_debug_infos(
238 std::shared_ptr<sham::DeviceScheduler> dev_sched,
239 const std::vector<SendPayload> &message_send,
240 std::vector<RecvPayload> &message_recv,
245 const std::vector<u64> &send_vec_comm_ranks = comm_table.local_send_vec_comm_ranks;
246 const std::vector<u64> &global_comm_ranks = comm_table.global_comm_ranks;
249 auto print_comm_mat = [&]() {
253 std::string accum =
"";
256 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
257 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
260 accum += shambase::format(
264 message_send[send_idx].payload->get_size());
273 matrix =
"\n" + matrix;
276 logger::raw_ln(
"comm matrix:", matrix);
284 auto show_alloc_state = [&]() {
288 std::string accum = shambase::format(
289 "rank = {} maxmem = {}\n",
300 logger::raw_ln(
"alloc state:", log);
309 void sparse_comm_isend_probe_count_irecv(
310 std::shared_ptr<sham::DeviceScheduler> dev_sched,
311 const std::vector<SendPayload> &message_send,
312 std::vector<RecvPayload> &message_recv,
317 const std::vector<u64> &send_vec_comm_ranks = comm_table.local_send_vec_comm_ranks;
318 const std::vector<u64> &global_comm_ranks = comm_table.global_comm_ranks;
322 std::vector<MPI_Request> rqs;
326 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
327 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
331 auto &payload = message_send[send_idx].payload;
333 rqs.push_back(MPI_Request{});
334 u32 rq_index = rqs.size() - 1;
335 auto &rq = rqs[rq_index];
337 int send_sz = check_payload_size_is_int(payload->get_size(), global_comm_ranks);
347 payload->get_ptr(), send_sz, MPI_BYTE, comm_ranks.y(), i, MPI_COMM_WORLD, &rq);
354 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
355 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
360 payload.sender_ranks = comm_ranks.x();
362 rqs.push_back(MPI_Request{});
363 u32 rq_index = rqs.size() - 1;
364 auto &rq = rqs[rq_index];
371 payload.payload = std::make_unique<shamcomm::CommunicationBuffer>(cnt, dev_sched);
381 payload.payload->get_ptr(),
389 message_recv.push_back(std::move(payload));
393 std::vector<MPI_Status> st_lst(rqs.size());
397 void sparse_comm_allgather_isend_irecv(
398 std::shared_ptr<sham::DeviceScheduler> dev_sched,
399 const std::vector<SendPayload> &message_send,
400 std::vector<RecvPayload> &message_recv,
405 const std::vector<u64> &send_vec_comm_ranks = comm_table.local_send_vec_comm_ranks;
406 const std::vector<u64> &global_comm_ranks = comm_table.global_comm_ranks;
412 std::vector<int> comm_sizes_loc = {};
413 comm_sizes_loc.resize(message_send.size());
414 for (
u64 i = 0; i < message_send.size(); i++) {
416 = check_payload_size_is_int(message_send[i].payload->get_size(), global_comm_ranks);
420 std::vector<int> comm_sizes = {};
424 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
425 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
427 i32 sender = comm_ranks.x();
428 i32 receiver = comm_ranks.y();
432 payload.sender_ranks = sender;
433 i32 cnt = comm_sizes[i];
435 payload.payload = std::make_unique<shamcomm::CommunicationBuffer>(cnt, dev_sched);
437 message_recv.push_back(std::move(payload));
442 std::vector<rq_info> rqs_infos;
452 for (
u32 i = 0; i < global_comm_ranks.size(); i++) {
453 u32_2 comm_ranks = sham::unpack32(global_comm_ranks[i]);
455 i32 sender = comm_ranks.x();
456 i32 receiver = comm_ranks.y();
458 i32 tag = tag_map[sender];
461 bool trigger_check =
false;
465 auto &payload = message_send.at(send_idx).payload;
467 auto &rq = rqs.new_request();
471 .receiver = receiver,
472 .size = payload->get_size(),
477 SHAM_ASSERT(payload->get_size() == comm_sizes_loc[send_idx]);
488 comm_sizes_loc[send_idx],
501 auto &payload = message_recv.at(recv_idx).payload;
503 auto &rq = rqs.new_request();
507 .receiver = receiver,
508 .size =
u64(comm_sizes[i]),
521 payload->get_ptr(), comm_sizes[i], MPI_BYTE, sender, tag, MPI_COMM_WORLD, &rq);
528 u64 in_flight_lim = SHAM_SPARSE_COMM_INFLIGHT_LIM;
529 if (in_flight > in_flight_lim) {
534 f64 last_print_time = 0;
536 shambase::Timer twait;
541 report_unfinished_requests(rqs, rqs_infos);
544 if (twait.
elapsed_sec() - last_print_time > print_freq) {
547 "too many messages in flight :",
553 in_flight = rqs.remain_count();
554 }
while (in_flight > in_flight_lim);
558 test_event_completions(rqs.requests(), rqs_infos);
Provides a helper class to manage a list of MPI requests.
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit signed integer
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
Class Timer measures the time elapsed since the timer was started.
f64 elapsed_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
void start()
Starts the timer.
void stop()
Stops the timer and stores the elapsed time in nanoseconds.
This header file contains utility functions related to exception handling in the code.
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
Allgatherv on a vector with size query (size-querying variant of vector_allgatherv_ks) //TODO add fault...
void gather_str(const std::string &send_vec, std::string &recv_vec)
Gathers a string from all nodes and stores the result in a std::string.
MemPerfInfos get_mem_perf_info()
Retrieve the memory performance information.
std::string readable_sizeof(double size)
Given a sizeof value, return a readable string. Example: readable_sizeof(1e9) -> "1....
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
std::string getenv_str_default_register(const char *env_var, std::string default_val, std::string desc)
Get the content of the environment variable if it exists and register its documentation,...
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
i32 world_size()
Gives the size of the MPI communicator.
void warn_ln(std::string module_name, Types... var2)
Prints a log message with multiple arguments followed by a newline.
void err_ln(std::string module_name, Types... var2)
Prints a log message with multiple arguments followed by a newline.
shambase::details::BasicStackEntry StackEntry
Alias for shambase::details::BasicStackEntry.
size_t max_allocated_byte_device
max bytes allocated on the device
Functions related to the MPI communicator.
void Get_count(const MPI_Status *status, MPI_Datatype datatype, int *count)
MPI wrapper for MPI_Get_count.
void Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Irecv.
void Probe(int source, int tag, MPI_Comm comm, MPI_Status *status)
MPI wrapper for MPI_Probe.
void Barrier(MPI_Comm comm)
MPI wrapper for MPI_Barrier.
void Waitall(int count, MPI_Request array_of_requests[], MPI_Status *array_of_statuses)
MPI wrapper for MPI_Waitall.
void Test(MPI_Request *request, int *flag, MPI_Status *status)
MPI wrapper for MPI_Test.
void Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Isend.