30namespace shamalgs::collective {
// Decode one packed message descriptor (u64_2) back into a CommMessageInfo.
// comm_info.x() carries both ranks packed as two u32 inside one u64;
// comm_info.y() carries the message size in bytes.
32 CommMessageInfo unpack(u64_2 comm_info) {
33 u64 comm_vec = comm_info.x();
34 size_t message_size = comm_info.y();
// sham::unpack32 splits the u64 back into the (sender, receiver) u32 pair
// that sham::pack32 produced on the sending side.
35 u32_2 comm_ranks = sham::unpack32(comm_vec);
36 u32 sender = comm_ranks.x();
37 u32 receiver = comm_ranks.y();
// Guard: a zero-size message should never appear in the table
// (the throw statement is not visible in this excerpt).
39 if (message_size == 0) {
41 "Message size is 0 for rank {}, sender = {}, receiver = {}",
// Ranks are widened back to i32 to match CommMessageInfo's signed rank fields.
47 return CommMessageInfo{
49 static_cast<i32>(sender),
50 static_cast<i32>(receiver),
// fetch_global_message_data: pack each locally-submitted message into a u64_2
// (both ranks packed into .x via sham::pack32, byte size in .y) and gather the
// descriptors from all ranks into one global list.
// (The signature line and the allgatherv call are not visible in this excerpt —
// presumably vector_allgatherv fills global_data; TODO confirm against full file.)
58 const std::vector<CommMessageInfo> &messages_send) {
60 std::vector<u64_2> local_data = std::vector<u64_2>(messages_send.size());
62 for (
size_t i = 0; i < messages_send.size(); i++) {
63 u32 sender =
static_cast<u32>(messages_send[i].rank_sender);
64 u32 receiver =
static_cast<u32>(messages_send[i].rank_receiver);
65 size_t message_size = messages_send[i].message_size;
// Sanity check (throw body not visible): only the rank that owns the data
// may submit this entry.
// NOTE(review): "posses" should read "possess" in this error message text.
69 "You are trying to send a message from a rank that does not posses it\n"
70 " sender = {}, receiver = {}, world_rank = {}",
// One u64_2 fully describes a message: packed rank pair + byte size.
76 local_data[i] = u64_2{sham::pack32(sender, receiver), message_size};
// Receives the concatenation of every rank's local_data (gather call missing here).
79 std::vector<u64_2> global_data;
// decode_all_message: unpack every gathered u64_2 descriptor back into a
// CommMessageInfo via unpack(). (Signature line not visible in this excerpt.)
87 std::vector<CommMessageInfo> message_all(global_data.size());
88 for (
u64 i = 0; i < global_data.size(); i++) {
89 message_all[i] = unpack(global_data[i]);
// compute_tags: assign an MPI tag to every message, counted per sending rank,
// so that multiple messages between the same (sender, receiver) pair remain
// distinguishable. (tag_map declaration and the increment of tag_map_ref are
// not visible in this excerpt — presumably tag_map_ref++ follows; TODO confirm.)
100 for (
u64 i = 0; i < message_all.size(); i++) {
101 auto &message_info = message_all[i];
102 auto sender = message_info.rank_sender;
// Per-sender running counter: the current value becomes this message's tag.
105 i32 &tag_map_ref = tag_map[
static_cast<size_t>(sender)];
106 i32 tag = tag_map_ref;
109 message_info.message_tag = tag;
// build_sparse_exchange_table: construct the full communication table for a
// sparse exchange.
//
// Steps (as visible in this excerpt):
//   1. Gather every rank's message descriptors (fetch_global_message_data).
//   2. Decode them into CommMessageInfo entries (decode_all_message).
//   3. Assign per-sender tags (compute_tags).
//   4. Lay messages out into send/recv byte buffers no larger than
//      max_alloc_size, recording (buffer id, byte offset) per message.
//   5. Split the global list into the subsets this rank sends / receives.
//
// Several interior lines (conditionals selecting send vs recv side, buffer-id
// increments, the returned CommTable aggregate) are missing from this excerpt.
113 CommTable build_sparse_exchange_table(
114 const std::vector<CommMessageInfo> &messages_send,
size_t max_alloc_size) {
117 std::vector<u64_2> global_data = fetch_global_message_data(messages_send);
119 std::vector<CommMessageInfo> message_all = decode_all_message(global_data);
121 compute_tags(message_all);
// Accumulated byte size of each send/recv buffer, indexed by buffer id.
127 std::vector<size_t> send_buf_sizes{};
128 std::vector<size_t> recv_buf_sizes{};
// Running offsets within the current send/recv buffer, and the current ids.
133 size_t tmp_recv_offset = 0;
134 size_t tmp_send_offset = 0;
135 size_t send_buf_id = 0;
136 size_t recv_buf_id = 0;
137 for (
u64 i = 0; i < message_all.size(); i++) {
138 auto &message_info = message_all[i];
140 auto sender = message_info.rank_sender;
141 auto receiver = message_info.rank_receiver;
// --- send side (guard selecting sender == world_rank is not visible here) ---
// A single message must fit in one buffer; otherwise it can never be placed.
145 if (message_info.message_size > max_alloc_size) {
148 "Message size is greater than the max alloc size\n"
149 " message_size = {}, max_alloc_size = {}",
150 message_info.message_size,
// Lazily open the first send buffer.
154 if (send_buf_sizes.size() == 0) {
155 send_buf_sizes.push_back(0);
// Current buffer would overflow -> start a new one (buf-id bump not visible).
158 if (tmp_send_offset + message_info.message_size >= max_alloc_size) {
161 send_buf_sizes.push_back(0);
// Record placement, then advance offset and buffer total.
165 message_info.message_bytebuf_offset_send = {send_buf_id, tmp_send_offset};
166 tmp_send_offset += message_info.message_size;
167 send_buf_sizes.at(send_buf_id) += message_info.message_size;
// --- recv side (mirrors the send side; guard not visible here) ---
174 if (message_info.message_size > max_alloc_size) {
177 "Message size is greater than the max alloc size\n"
178 " message_size = {}, max_alloc_size = {}",
179 message_info.message_size,
183 if (recv_buf_sizes.size() == 0) {
184 recv_buf_sizes.push_back(0);
187 if (tmp_recv_offset + message_info.message_size >= max_alloc_size) {
190 recv_buf_sizes.push_back(0);
194 message_info.message_bytebuf_offset_recv = {recv_buf_id, tmp_recv_offset};
195 tmp_recv_offset += message_info.message_size;
196 recv_buf_sizes.at(recv_buf_id) += message_info.message_size;
// Write the updated placement info back into the global list.
201 message_all[i] = message_info;
// Second pass: collect the messages this rank sends/receives, keeping each
// one's index into message_all (send_idx/recv_idx counting not visible here).
214 std::vector<CommMessageInfo> ret_message_send(send_idx);
215 std::vector<CommMessageInfo> ret_message_recv(recv_idx);
217 std::vector<size_t> send_message_global_ids(send_idx);
218 std::vector<size_t> recv_message_global_ids(recv_idx);
223 for (
size_t i = 0; i < message_all.size(); i++) {
224 auto message_info = message_all[i];
226 ret_message_send[send_idx] = message_info;
227 send_message_global_ids[send_idx] = i;
231 ret_message_recv[recv_idx] = message_info;
232 recv_message_global_ids[recv_idx] = i;
// Assembled CommTable return value (surrounding fields not visible here).
241 send_message_global_ids,
242 recv_message_global_ids,
// sparse_exchange (raw-pointer overload): post non-blocking MPI transfers for
// every message in the comm table that involves this rank, while throttling
// the number of simultaneously in-flight requests.
//
// @param dev_sched      device scheduler (usage not visible in this excerpt)
// @param bytebuffer_send one raw pointer per send buffer id in the table
// @param bytebuffer_recv one raw pointer per recv buffer id in the table
// @param comm_table     table built by build_sparse_exchange_table
247 void sparse_exchange(
248 std::shared_ptr<sham::DeviceScheduler> dev_sched,
249 const std::vector<const u8 *> &bytebuffer_send,
250 const std::vector<u8 *> &bytebuffer_recv,
251 const CommTable &comm_table) {
// Cap on in-flight MPI requests, to avoid exhausting MPI internal resources
// on large exchanges.
255 u32 SHAM_SPARSE_COMM_INFLIGHT_LIM = 128;
258 for (
size_t i = 0; i < comm_table.message_all.size(); i++) {
260 auto message_info = comm_table.message_all[i];
// Send side: source pointer = buffer base + recorded byte offset, then an
// Isend is posted (full argument list not visible in this excerpt).
264 auto ptr = bytebuffer_send.at(off_info.buf_id) + off_info.data_offset;
265 auto &rq = rqs.new_request();
270 message_info.rank_receiver,
// Recv side: symmetric Irecv into the destination buffer.
278 auto ptr = bytebuffer_recv.at(off_info.buf_id) + off_info.data_offset;
279 auto &rq = rqs.new_request();
284 message_info.rank_sender,
// Block partially whenever too many requests are in flight
// (120 / 10 are presumably spin-wait timing parameters — TODO confirm).
290 rqs.spin_lock_partial_wait(SHAM_SPARSE_COMM_INFLIGHT_LIM, 120, 10);
// sparse_exchange (USM-buffer overload): validates the buffers against the
// comm table, extracts raw pointers, and delegates to the raw-pointer
// overload above.
//
// Checks performed (as visible in this excerpt):
//   - send and recv buffer vectors must not alias (no in-place exchange);
//   - buffer vector lengths must match the table's per-buffer size lists;
//   - each buffer must be at least as large as the table says it needs;
//   - device buffers require an MPI-direct (GPU-aware) capable device.
295 template<sham::USMKindTarget target>
296 void sparse_exchange(
297 std::shared_ptr<sham::DeviceScheduler> dev_sched,
300 const CommTable &comm_table) {
// In-place exchange is rejected: MPI send/recv regions must not overlap.
304 if (&bytebuffer_send == &bytebuffer_recv) {
306 "In-place sparse_exchange is not supported. Send and receive buffers must be "
// NOTE(review): these two messages say "greater than" but the condition is a
// length (!=) mismatch — message text is misleading; not altered here.
310 if (comm_table.send_total_sizes.size() != bytebuffer_send.size()) {
312 "The send total size is greater than the send buffer size\n"
313 " send_total_sizes = {}, send_buffer_size = {}",
314 comm_table.send_total_sizes.size(),
315 bytebuffer_send.size()));
318 if (comm_table.recv_total_sizes.size() != bytebuffer_recv.size()) {
320 "The recv total size is greater than the recv buffer size\n"
321 " recv_total_sizes = {}, recv_buffer_size = {}",
322 comm_table.recv_total_sizes.size(),
323 bytebuffer_recv.size()));
// Per-buffer capacity checks: each USM buffer must hold its table total.
326 for (
size_t i = 0; i < comm_table.send_total_sizes.size(); i++) {
327 if (comm_table.send_total_sizes[i] > bytebuffer_send[i]->get_size()) {
329 "The send total size is greater than the send buffer size\n"
330 " send_total_sizes = {}, send_buffer_size = {}, buf_id = {}",
331 comm_table.send_total_sizes[i],
332 bytebuffer_send[i]->get_size(),
337 for (
size_t i = 0; i < comm_table.recv_total_sizes.size(); i++) {
338 if (comm_table.recv_total_sizes[i] > bytebuffer_recv[i]->get_size()) {
340 "The recv total size is greater than the recv buffer size\n"
341 " recv_total_sizes = {}, recv_buffer_size = {}, buf_id = {}",
342 comm_table.recv_total_sizes[i],
343 bytebuffer_recv[i]->get_size(),
// Device-resident buffers require GPU-aware MPI (direct capable).
348 bool direct_gpu_capable = dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable;
352 "You are trying to use a device buffer on the device but the device is not "
// Extract raw pointers from the USM holders (loop bodies not visible here),
// then delegate to the raw-pointer overload.
357 std::vector<const u8 *> send_ptrs(bytebuffer_send.size());
358 std::vector<u8 *> recv_ptrs(bytebuffer_recv.size());
361 for (
size_t i = 0; i < bytebuffer_send.size(); i++) {
366 for (
size_t i = 0; i < bytebuffer_recv.size(); i++) {
372 sparse_exchange(dev_sched, send_ptrs, recv_ptrs, comm_table);
// Post-exchange bookkeeping loops (bodies not visible in this excerpt —
// presumably event/dependency completion on each buffer; TODO confirm).
374 for (
size_t i = 0; i < bytebuffer_send.size(); i++) {
378 for (
size_t i = 0; i < bytebuffer_recv.size(); i++) {
// Explicit instantiations of the USM overload for device- and host-resident
// buffers, so the template definition can live in this translation unit.
384 template void sparse_exchange<sham::device>(
385 std::shared_ptr<sham::DeviceScheduler> dev_sched,
388 const CommTable &comm_table);
390 template void sparse_exchange<sham::host>(
391 std::shared_ptr<sham::DeviceScheduler> dev_sched,
394 const CommTable &comm_table);
Provides a helper class to manage a list of MPI requests.
This file contains the declaration of the USMPtrHolder class.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
A buffer allocated in USM (Unified Shared Memory)
Class to manage a list of SYCL events.
void wait()
Wait for all events in the list to be finished.
This header file contains utility functions related to exception handling in the code.
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on a vector with size query (size-querying variant of vector_allgatherv_ks). // TODO: add fault handling
Define the fmt formatters for sycl::vec.
Use this header to include MPI properly.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the pointer is null.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
i32 world_size()
Gives the size of the MPI communicator.
Utilities for safe type narrowing conversions.
void compute_tags(std::vector< CommMessageInfo > &message_all)
compute message tags
std::vector< u64_2 > fetch_global_message_data(const std::vector< CommMessageInfo > &messages_send)
gather the packed u64_2 message descriptors from all ranks into one global list
std::vector< CommMessageInfo > decode_all_message(const std::vector< u64_2 > &global_data)
decode the packed global message data into CommMessageInfo entries
This file contains the definition for the stacktrace related functionality.
#define __shamrock_stack_entry()
Macro to create a stack entry.
Functions related to the MPI communicator.
void Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Irecv.
void Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Isend.