Shamrock 2025.10.0
Astrophysical Code
sparse_exchange.cpp
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

#include "shambase/memory.hpp"
#include "shambackends/math.hpp"
#include "shamcomm/mpi.hpp"
#include <stdexcept>

namespace shamalgs::collective {

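    /// Decode a packed u64_2 message descriptor: .x() holds the (sender, receiver)
    /// rank pair packed with sham::pack32, .y() holds the message size in bytes.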
    CommMessageInfo unpack(u64_2 comm_info) {
        u64 comm_vec        = comm_info.x();
        size_t message_size = comm_info.y();
        u32_2 comm_ranks    = sham::unpack32(comm_vec);
        u32 sender          = comm_ranks.x();
        u32 receiver        = comm_ranks.y();

        if (message_size == 0) {
            shambase::throw_with_loc<std::runtime_error>(shambase::format(
                "Message size is 0 for rank {}, sender = {}, receiver = {}",
                shamcomm::world_rank(),
                sender,
                receiver));
        }

        return CommMessageInfo{
            message_size,
            static_cast<i32>(sender),
            static_cast<i32>(receiver),
            std::nullopt,
            std::nullopt,
            std::nullopt};
    }

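    /// Gather the packed u64_2 descriptors of every rank's messages (allgatherv of
    /// the local send list), so that all ranks see the full communication pattern.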
    std::vector<u64_2> fetch_global_message_data(
        const std::vector<CommMessageInfo> &messages_send) {

        std::vector<u64_2> local_data = std::vector<u64_2>(messages_send.size());

        for (size_t i = 0; i < messages_send.size(); i++) {
            u32 sender          = static_cast<u32>(messages_send[i].rank_sender);
            u32 receiver        = static_cast<u32>(messages_send[i].rank_receiver);
            size_t message_size = messages_send[i].message_size;

            if (sender != shamcomm::world_rank()) {
                shambase::throw_with_loc<std::invalid_argument>(shambase::format(
                    "You are trying to send a message from a rank that does not possess it\n"
                    " sender = {}, receiver = {}, world_rank = {}",
                    sender,
                    receiver,
                    shamcomm::world_rank()));
            }

            local_data[i] = u64_2{sham::pack32(sender, receiver), message_size};
        }

        std::vector<u64_2> global_data;
        vector_allgatherv(local_data, global_data, MPI_COMM_WORLD);

        return global_data; // named return value optimisation should apply here
    }

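    /// Decode the gathered u64_2 descriptors back into a list of CommMessageInfo
    /// covering every message exchanged by every rank.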
    std::vector<CommMessageInfo> decode_all_message(const std::vector<u64_2> &global_data) {
        std::vector<CommMessageInfo> message_all(global_data.size());
        for (u64 i = 0; i < global_data.size(); i++) {
            message_all[i] = unpack(global_data[i]);
        }

        return message_all;
    }

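    /// Compute message tags: every sender numbers its outgoing messages 0, 1, 2, ...
    /// in the order they appear in the global list, so each (sender, tag) pair is
    /// unique and the matching Isend/Irecv calls pair up unambiguously.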
    void compute_tags(std::vector<CommMessageInfo> &message_all) {

        std::vector<i32> tag_map(shamcomm::world_size(), 0);

        for (u64 i = 0; i < message_all.size(); i++) {
            auto &message_info = message_all[i];
            auto sender        = message_info.rank_sender;

            // tagging logic
            i32 &tag_map_ref = tag_map[static_cast<size_t>(sender)];
            i32 tag          = tag_map_ref;
            tag_map_ref++;

            message_info.message_tag = tag;
        }
    }

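    /// Build the sparse exchange table: gather the message descriptors of all ranks,
    /// assign tags, then pack the local send and recv messages back to back into byte
    /// buffers of at most max_alloc_size bytes each, recording per-message offsets.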
    CommTable build_sparse_exchange_table(
        const std::vector<CommMessageInfo> &messages_send, size_t max_alloc_size) {
        __shamrock_stack_entry();

        std::vector<u64_2> global_data = fetch_global_message_data(messages_send);

        std::vector<CommMessageInfo> message_all = decode_all_message(global_data);

        compute_tags(message_all);

        // Compute offsets: pack each local message behind the previous one, opening a
        // new buffer whenever max_alloc_size would be reached

        std::vector<size_t> send_buf_sizes{};
        std::vector<size_t> recv_buf_sizes{};

        u32 send_idx = 0;
        u32 recv_idx = 0;
        {
            size_t tmp_recv_offset = 0;
            size_t tmp_send_offset = 0;
            size_t send_buf_id     = 0;
            size_t recv_buf_id     = 0;
            for (u64 i = 0; i < message_all.size(); i++) {
                auto &message_info = message_all[i];

                auto sender   = message_info.rank_sender;
                auto receiver = message_info.rank_receiver;

                // offset logic (& buffer selection)
                if (sender == shamcomm::world_rank()) {
                    if (message_info.message_size > max_alloc_size) {
                        shambase::throw_with_loc<std::invalid_argument>(
                            shambase::format(
                                "Message size is greater than the max alloc size\n"
                                " message_size = {}, max_alloc_size = {}",
                                message_info.message_size,
                                max_alloc_size));
                    }

                    if (send_buf_sizes.size() == 0) {
                        send_buf_sizes.push_back(0);
                    }

                    if (tmp_send_offset + message_info.message_size >= max_alloc_size) {
                        send_buf_id++;
                        tmp_send_offset = 0;
                        send_buf_sizes.push_back(0);
                        // logger::info_ln("sparse comm", "is using multiple buffers (send) !");
                    }

                    message_info.message_bytebuf_offset_send = {send_buf_id, tmp_send_offset};
                    tmp_send_offset += message_info.message_size;
                    send_buf_sizes.at(send_buf_id) += message_info.message_size;

                    send_idx++;
                }

                if (receiver == shamcomm::world_rank()) {

                    if (message_info.message_size > max_alloc_size) {
                        shambase::throw_with_loc<std::invalid_argument>(
                            shambase::format(
                                "Message size is greater than the max alloc size\n"
                                " message_size = {}, max_alloc_size = {}",
                                message_info.message_size,
                                max_alloc_size));
                    }

                    if (recv_buf_sizes.size() == 0) {
                        recv_buf_sizes.push_back(0);
                    }

                    if (tmp_recv_offset + message_info.message_size >= max_alloc_size) {
                        recv_buf_id++;
                        tmp_recv_offset = 0;
                        recv_buf_sizes.push_back(0);
                        // logger::info_ln("sparse comm", "is using multiple buffers (recv) !");
                    }

                    message_info.message_bytebuf_offset_recv = {recv_buf_id, tmp_recv_offset};
                    tmp_recv_offset += message_info.message_size;
                    recv_buf_sizes.at(recv_buf_id) += message_info.message_size;

                    recv_idx++;
                }

                message_all[i] = message_info;
            }
        }

        //{
        //    logger::info_ln("sparse comm", "send_buf_sizes :", send_buf_sizes);
        //    logger::info_ln("sparse comm", "recv_buf_sizes :", recv_buf_sizes);
        //}

        // now that all comms have been computed we can build the send and recv message lists

        std::vector<CommMessageInfo> ret_message_send(send_idx);
        std::vector<CommMessageInfo> ret_message_recv(recv_idx);

        std::vector<size_t> send_message_global_ids(send_idx);
        std::vector<size_t> recv_message_global_ids(recv_idx);

        send_idx = 0;
        recv_idx = 0;

        for (size_t i = 0; i < message_all.size(); i++) {
            auto message_info = message_all[i];
            if (message_info.rank_sender == shamcomm::world_rank()) {
                ret_message_send[send_idx]        = message_info;
                send_message_global_ids[send_idx] = i;
                send_idx++;
            }
            if (message_info.rank_receiver == shamcomm::world_rank()) {
                ret_message_recv[recv_idx]        = message_info;
                recv_message_global_ids[recv_idx] = i;
                recv_idx++;
            }
        }

        return CommTable{
            ret_message_send,
            message_all,
            ret_message_recv,
            send_message_global_ids,
            recv_message_global_ids,
            send_buf_sizes,
            recv_buf_sizes};
    }

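    /// Post the non-blocking sends and receives described by comm_table, reading from
    /// and writing into raw byte buffers at the per-message offsets stored in the table.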
    void sparse_exchange(
        std::shared_ptr<sham::DeviceScheduler> dev_sched,
        const std::vector<const u8 *> &bytebuffer_send,
        const std::vector<u8 *> &bytebuffer_recv,
        const CommTable &comm_table) {

        __shamrock_stack_entry();

        u32 SHAM_SPARSE_COMM_INFLIGHT_LIM = 128; // TODO: use the env variable

        RequestList rqs;
        for (size_t i = 0; i < comm_table.message_all.size(); i++) {

            auto message_info = comm_table.message_all[i];

            if (message_info.rank_sender == shamcomm::world_rank()) {
                auto off_info = shambase::get_check_ref(message_info.message_bytebuf_offset_send);
                auto ptr      = bytebuffer_send.at(off_info.buf_id) + off_info.data_offset;
                auto &rq      = rqs.new_request();
                shamcomm::mpi::Isend(
                    ptr,
                    shambase::narrow_or_throw<i32>(message_info.message_size),
                    MPI_BYTE,
                    message_info.rank_receiver,
                    shambase::get_check_ref(message_info.message_tag),
                    MPI_COMM_WORLD,
                    &rq);
            }

            if (message_info.rank_receiver == shamcomm::world_rank()) {
                auto off_info = shambase::get_check_ref(message_info.message_bytebuf_offset_recv);
                auto ptr      = bytebuffer_recv.at(off_info.buf_id) + off_info.data_offset;
                auto &rq      = rqs.new_request();
                shamcomm::mpi::Irecv(
                    ptr,
                    shambase::narrow_or_throw<i32>(message_info.message_size),
                    MPI_BYTE,
                    message_info.rank_sender,
                    shambase::get_check_ref(message_info.message_tag),
                    MPI_COMM_WORLD,
                    &rq);
            }

            // keep at most SHAM_SPARSE_COMM_INFLIGHT_LIM requests in flight
            rqs.spin_lock_partial_wait(SHAM_SPARSE_COMM_INFLIGHT_LIM, 120, 10);
        }
        rqs.wait_all();
    }

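    /// DeviceBuffer overload: checks the byte buffers against the sizes recorded in
    /// comm_table, acquires the USM pointers (guarding against non GPU-aware MPI when
    /// target is device memory), then forwards to the raw-pointer overload above.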
    template<sham::USMKindTarget target>
    void sparse_exchange(
        std::shared_ptr<sham::DeviceScheduler> dev_sched,
        std::vector<std::unique_ptr<sham::DeviceBuffer<u8, target>>> &bytebuffer_send,
        std::vector<std::unique_ptr<sham::DeviceBuffer<u8, target>>> &bytebuffer_recv,
        const CommTable &comm_table) {

        __shamrock_stack_entry();

        if (&bytebuffer_send == &bytebuffer_recv) {
            shambase::throw_with_loc<std::invalid_argument>(
                "In-place sparse_exchange is not supported. Send and receive buffers must be "
                "distinct.");
        }

        if (comm_table.send_total_sizes.size() != bytebuffer_send.size()) {
            shambase::throw_with_loc<std::invalid_argument>(shambase::format(
                "The number of send buffer sizes does not match the number of send buffers\n"
                " send_total_sizes = {}, send_buffer_count = {}",
                comm_table.send_total_sizes.size(),
                bytebuffer_send.size()));
        }

        if (comm_table.recv_total_sizes.size() != bytebuffer_recv.size()) {
            shambase::throw_with_loc<std::invalid_argument>(shambase::format(
                "The number of recv buffer sizes does not match the number of recv buffers\n"
                " recv_total_sizes = {}, recv_buffer_count = {}",
                comm_table.recv_total_sizes.size(),
                bytebuffer_recv.size()));
        }

        for (size_t i = 0; i < comm_table.send_total_sizes.size(); i++) {
            if (comm_table.send_total_sizes[i] > bytebuffer_send[i]->get_size()) {
                shambase::throw_with_loc<std::invalid_argument>(shambase::format(
                    "The send total size is greater than the send buffer size\n"
                    " send_total_sizes = {}, send_buffer_size = {}, buf_id = {}",
                    comm_table.send_total_sizes[i],
                    bytebuffer_send[i]->get_size(),
                    i));
            }
        }

        for (size_t i = 0; i < comm_table.recv_total_sizes.size(); i++) {
            if (comm_table.recv_total_sizes[i] > bytebuffer_recv[i]->get_size()) {
                shambase::throw_with_loc<std::invalid_argument>(shambase::format(
                    "The recv total size is greater than the recv buffer size\n"
                    " recv_total_sizes = {}, recv_buffer_size = {}, buf_id = {}",
                    comm_table.recv_total_sizes[i],
                    bytebuffer_recv[i]->get_size(),
                    i));
            }
        }

        bool direct_gpu_capable = dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable;

        if (!direct_gpu_capable && target == sham::device) {
            shambase::throw_with_loc<std::runtime_error>(
                "You are trying to exchange device buffers, but the MPI implementation is not "
                "direct GPU capable");
        }

        std::vector<const u8 *> send_ptrs(bytebuffer_send.size());
        std::vector<u8 *> recv_ptrs(bytebuffer_recv.size());

        // acquire the USM pointers and wait for pending events before calling MPI
        sham::EventList depends_list;
        for (size_t i = 0; i < bytebuffer_send.size(); i++) {
            send_ptrs[i]
                = shambase::get_check_ref(bytebuffer_send[i]).get_read_access(depends_list);
        }

        for (size_t i = 0; i < bytebuffer_recv.size(); i++) {
            recv_ptrs[i]
                = shambase::get_check_ref(bytebuffer_recv[i]).get_write_access(depends_list);
        }
        depends_list.wait();

        sparse_exchange(dev_sched, send_ptrs, recv_ptrs, comm_table);

        // the exchange has completed by this point, so the buffers can be released
        // with an empty event
        for (size_t i = 0; i < bytebuffer_send.size(); i++) {
            shambase::get_check_ref(bytebuffer_send[i]).complete_event_state(sycl::event{});
        }

        for (size_t i = 0; i < bytebuffer_recv.size(); i++) {
            shambase::get_check_ref(bytebuffer_recv[i]).complete_event_state(sycl::event{});
        }
    }

    // template instantiations
    template void sparse_exchange<sham::device>(
        std::shared_ptr<sham::DeviceScheduler> dev_sched,
        std::vector<std::unique_ptr<sham::DeviceBuffer<u8, sham::device>>> &bytebuffer_send,
        std::vector<std::unique_ptr<sham::DeviceBuffer<u8, sham::device>>> &bytebuffer_recv,
        const CommTable &comm_table);

    template void sparse_exchange<sham::host>(
        std::shared_ptr<sham::DeviceScheduler> dev_sched,
        std::vector<std::unique_ptr<sham::DeviceBuffer<u8, sham::host>>> &bytebuffer_send,
        std::vector<std::unique_ptr<sham::DeviceBuffer<u8, sham::host>>> &bytebuffer_recv,
        const CommTable &comm_table);

} // namespace shamalgs::collective
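
For reference, a minimal usage sketch of the API above (not part of the file): every rank sends one message to the next rank in a ring, using host memory. The ring pattern, the buffer sizes, the DeviceBuffer constructor call, and the CommMessageInfo field order (taken from the aggregate initialisation in unpack above) are illustrative assumptions, not a verified Shamrock API reference.

    // hypothetical driver, assuming an already initialised sham::DeviceScheduler `dev_sched`
    using namespace shamalgs::collective;

    size_t n              = 1024;            // bytes per message (illustrative)
    size_t max_alloc_size = size_t(1) << 30; // open a new buffer past 1 GiB (illustrative)

    i32 me   = shamcomm::world_rank();
    i32 next = (me + 1) % shamcomm::world_size();

    // one outgoing message; tag and buffer offsets are filled in by the table build
    std::vector<CommMessageInfo> msgs
        = {CommMessageInfo{n, me, next, std::nullopt, std::nullopt, std::nullopt}};

    CommTable table = build_sparse_exchange_table(msgs, max_alloc_size);

    // allocate one byte buffer per entry of the size lists, sized accordingly
    // (the DeviceBuffer constructor arguments here are assumed)
    std::vector<std::unique_ptr<sham::DeviceBuffer<u8, sham::host>>> send_bufs, recv_bufs;
    for (size_t sz : table.send_total_sizes)
        send_bufs.push_back(std::make_unique<sham::DeviceBuffer<u8, sham::host>>(sz, dev_sched));
    for (size_t sz : table.recv_total_sizes)
        recv_bufs.push_back(std::make_unique<sham::DeviceBuffer<u8, sham::host>>(sz, dev_sched));

    // ... fill send_bufs at the offsets recorded in the table ...

    sparse_exchange<sham::host>(dev_sched, send_bufs, recv_bufs, table);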