Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
distributedDataComm.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
19#include "shambase/memory.hpp"
25#include "shamcmdopt/env.hpp"
26#include <memory>
27#include <vector>
28
29auto SPARSE_COMM_MODE = shamcmdopt::getenv_str_default_register(
30 "SPARSE_COMM_MODE", "new", "Sparse communication mode (new=with cache, old=without cache)");
31
32namespace {
33 struct SparseCommMode {
34 enum Mode { NEW, OLD };
35 };
36
37 constexpr auto parse_sparse_comm_mode = []() {
38 if (SPARSE_COMM_MODE == "new") {
39 return SparseCommMode::NEW;
40 } else if (SPARSE_COMM_MODE == "old") {
41 return SparseCommMode::OLD;
42 } else {
43 throw std::invalid_argument(
44 "Invalid sparse communication mode, valid modes are: new, old");
45 }
46 };
47
48 bool use_old_sparse_comm_mode = parse_sparse_comm_mode() == SparseCommMode::OLD;
49
50 bool warning_printed = false;
51} // namespace
52
53namespace shamalgs::collective {
54
55 namespace details {
// Descriptor of one buffer to exchange: the sender and receiver ids (callers map them
// to ranks through their rank_getter) and the payload length.
56 struct DataTmp {
57 u64 sender;
58 u64 receiver;
59 u64 length;
// NOTE(review): listing line 60 is missing here. Other code in this file reads `d.data`
// and brace-initialises a fourth member from a sham::DeviceBuffer<u8>, so it presumably
// declares that buffer member -- confirm against the real source.
61
// Serialized footprint of this entry: three u64 header fields (sender, receiver,
// length) plus `length` elements of type u8.
62 SerializeSize get_ser_sz() {
63 return SerializeHelper::serialize_byte_size<u64>() * 3
64 + SerializeHelper::serialize_byte_size<u8>(length);
65 }
66 };
67
68 auto serialize_group_data(
69 std::shared_ptr<sham::DeviceScheduler> dev_sched,
70 std::map<std::pair<i32, i32>, std::vector<DataTmp>> &send_data)
71 -> std::map<std::pair<i32, i32>, SerializeHelper> {
72
73 StackEntry stack_loc{};
74
75 std::map<std::pair<i32, i32>, SerializeHelper> serializers;
76
77 for (auto &[key, vect] : send_data) {
78 SerializeSize byte_sz = SerializeHelper::serialize_byte_size<u64>(); // vec length
79 for (DataTmp &d : vect) {
80 byte_sz += d.get_ser_sz();
81 }
82 serializers.emplace(key, dev_sched);
83 serializers.at(key).allocate(byte_sz);
84 }
85
86 for (auto &[key, vect] : send_data) {
87 SerializeHelper &ser = serializers.at(key);
88 ser.write<u64>(vect.size());
89 for (DataTmp &d : vect) {
90 ser.write(d.sender);
91 ser.write(d.receiver);
92 ser.write(d.length);
93 ser.write_buf(d.data, d.length);
94 }
95 }
96
97 return serializers;
98 }
99
// NOTE(review): the line that opens this class (listing line 100) is missing from this
// listing; the type is referred to as PrepareCommUtil elsewhere in this file.
//
// One prepared outgoing message: the sender/receiver rank pair, the serialized size
// `sz`, references to the DataTmp entries it carries, and the serializer and final
// send buffer built from them.
101 public:
102 i32 sender_rank, receiver_rank;
103 SerializeSize sz;
104 std::vector<std::reference_wrapper<DataTmp>> sources;
105 std::unique_ptr<SerializeHelper> serializer = {};
106 std::unique_ptr<sham::DeviceBuffer<u8>> send_buf = {};
107
// Create the serializer on dev_sched and allocate `sz` for it.
108 void allocate_serializer(std::shared_ptr<sham::DeviceScheduler> dev_sched) {
109 serializer = std::make_unique<SerializeHelper>(dev_sched);
110 serializer->allocate(sz);
111 }
112
// Write the element count, then sender, receiver, length and buffer of every source.
// get_check_ref is documented as throwing when the unique_ptr is empty, i.e. when
// allocate_serializer was not called first.
113 void write_sources() {
114 SerializeHelper &ser = shambase::get_check_ref(serializer);
115 ser.write<u64>(sources.size());
116 for (DataTmp &d : sources) {
117 ser.write(d.sender);
118 ser.write(d.receiver);
119 ser.write(d.length);
120 ser.write_buf(d.data, d.length);
121 }
122 }
123
// Finalize the serializer and store the resulting buffer in send_buf.
124 void finalize_serializer() {
125 SerializeHelper &ser = shambase::get_check_ref(serializer);
126 send_buf = std::make_unique<sham::DeviceBuffer<u8>>(ser.finalize());
127 }
128 };
129
// Group the messages of send_data per (sender rank, receiver rank) key and split each
// group into several PrepareCommUtil entries so that the serialized size of an entry
// stays within max_comm_size. Every returned entry has been allocated, written and
// finalized.
//
// dev_sched     : device scheduler used to allocate the serializers
// send_data     : messages grouped by (sender rank, receiver rank)
// max_comm_size : upper bound on the total serialized size of one entry
130 auto serialize_group_data_max_size(
131 std::shared_ptr<sham::DeviceScheduler> dev_sched,
132 std::map<std::pair<i32, i32>, std::vector<DataTmp>> &send_data,
133 u64 max_comm_size) -> std::vector<PrepareCommUtil> {
134
135 StackEntry stack_loc{};
136
137 std::vector<PrepareCommUtil> ret;
138
// Flush the pending sources (if any) as one entry of `ret`, then reset the running
// size to the u64 count header and clear the sources list.
139 auto add_to_ret = [&](std::pair<i32, i32> key,
140 SerializeSize &byte_sz,
141 std::vector<std::reference_wrapper<DataTmp>> &sources) {
142 if (byte_sz.get_total_size() > max_comm_size) {
// NOTE(review): listing line 143 is missing; given the formatted message below it
// presumably raises an error -- confirm against the real source.
144 shambase::format("comm size too large: {}", byte_sz.get_total_size()));
145 }
146
147 auto [sender_rank, receiver_rank] = key;
148
149 if (sources.size() > 0) {
150 PrepareCommUtil next{sender_rank, receiver_rank, byte_sz, sources};
151 ret.push_back(std::move(next));
152 }
153
154 byte_sz = SerializeHelper::serialize_byte_size<u64>(); // vec length
155 sources = {};
156 };
157
158 for (auto &[key, vect] : send_data) {
159 SerializeSize byte_sz = SerializeHelper::serialize_byte_size<u64>(); // vec length
160 std::vector<std::reference_wrapper<DataTmp>> sources = {};
161
162 for (DataTmp &d : vect) {
163 std::reference_wrapper<DataTmp> d_ref = d;
164 auto dbyte_sz = d.get_ser_sz();
165
// Adding d would exceed the limit: flush what has been accumulated so far. A single
// element that is itself too large still gets accumulated below and is only caught
// by the size check of the next add_to_ret call.
166 if ((dbyte_sz + byte_sz).get_total_size() > max_comm_size) {
167 add_to_ret(key, byte_sz, sources);
168 // logger::raw_ln("comm split at", d.sender, d.receiver, d.length);
169 }
170
171 // logger::raw_ln(
172 // "add to sources", dbyte_sz.get_total_size(), byte_sz.get_total_size());
173
174 byte_sz += d.get_ser_sz();
175 sources.push_back(d_ref);
176 }
177
// flush the remainder of this group
178 add_to_ret(key, byte_sz, sources);
179 }
180
// The three phases below run as separate loops over all entries: allocate every
// serializer, then write every entry, then finalize every entry.
181 for (auto &c : ret) {
182 // logger::raw_ln(
183 // "allocate serializer", c.sender_rank, c.receiver_rank,
184 // c.sz.get_total_size());
185 c.allocate_serializer(dev_sched);
186 }
187
188 for (auto &c : ret) {
189 c.write_sources();
190 }
191
192 for (auto &c : ret) {
193 c.finalize_serializer();
194 }
195
196 return ret;
197 }
198
199 } // namespace details
200
// Sparse exchange of distributed data, variant that does not take a DDSCommCache.
//
// Every buffer of send_distrib_data is grouped by (sender rank, receiver rank), each
// group is serialized into one buffer, the buffers are exchanged, and the received
// buffers are split back into individual objects added to recv_distrib_data.
//
// dev_sched         : device scheduler used for serializers and buffers
// send_distrib_data : buffers to send, visited as (sender, receiver, buffer)
// recv_distrib_data : destination of the received (sender, receiver, buffer) objects
// rank_getter       : maps a sender/receiver id to a rank
// comm_table        : when set, the exchange goes through sparse_comm_c with this
//                     table, otherwise through base_sparse_comm
//
// Throws std::runtime_error (via make_except_with_loc) when rank_getter applied to a
// received object's sender does not match the rank the payload came from.
201 void distributed_data_sparse_comm_old(
202 sham::DeviceScheduler_ptr dev_sched,
203 SerializedDDataComm &send_distrib_data,
204 SerializedDDataComm &recv_distrib_data,
205 std::function<i32(u64)> rank_getter,
206 std::optional<SparseCommTable> comm_table) {
207
208 StackEntry stack_loc{};
209
210 using namespace shambase;
211 using DataTmp = details::DataTmp;
212
213 // prepare map
214 std::map<std::pair<i32, i32>, std::vector<DataTmp>> send_data;
215 send_distrib_data.for_each([&](u64 sender, u64 receiver, sham::DeviceBuffer<u8> &buf) {
216 std::pair<i32, i32> key = {rank_getter(sender), rank_getter(receiver)};
217
218 send_data[key].push_back(DataTmp{sender, receiver, buf.get_size(), buf});
219 });
220
221 // serialize together similar communications
222 std::map<std::pair<i32, i32>, SerializeHelper> serializers
223 = details::serialize_group_data(dev_sched, send_data);
224
225 // recover bufs from serializers
226 std::map<std::pair<i32, i32>, std::unique_ptr<sham::DeviceBuffer<u8>>> send_bufs;
227 {
228 NamedStackEntry stack_loc2{"recover bufs"};
229 for (auto &[key, ser] : serializers) {
230 send_bufs[key] = std::make_unique<sham::DeviceBuffer<u8>>(ser.finalize());
231 }
232 }
233
234 // prepare payload
235 std::vector<SendPayload> send_payoad;
236 {
237 NamedStackEntry stack_loc2{"prepare payload"};
// each payload is addressed to key.second, i.e. the receiver rank of the group
238 for (auto &[key, buf] : send_bufs) {
239 send_payoad.push_back(
240 {key.second,
241 std::make_unique<shamcomm::CommunicationBuffer>(
242 shambase::extract_pointer(buf), dev_sched)});
243 }
244 }
245
246 // sparse comm
247 std::vector<RecvPayload> recv_payload;
248
249 if (comm_table) {
250 sparse_comm_c(dev_sched, send_payoad, recv_payload, *comm_table);
251 } else {
252 base_sparse_comm(dev_sched, send_payoad, recv_payload);
253 }
254
255 // make serializers from recv buffs
256 struct RecvPayloadSer {
257 i32 sender_ranks;
258 SerializeHelper ser;
259 };
260
261 std::vector<RecvPayloadSer> recv_payload_bufs;
262
263 {
264 NamedStackEntry stack_loc2{"move payloads"};
265 for (RecvPayload &payload : recv_payload) {
266
267 shamcomm::CommunicationBuffer comm_buf = extract_pointer(payload.payload);
268
// NOTE(review): listing line 269 is missing; it presumably declares the `buf`
// variable that receives the converted buffer and is moved from below -- confirm.
270 = shamcomm::CommunicationBuffer::convert_usm(std::move(comm_buf));
271
272 recv_payload_bufs.push_back(
273 RecvPayloadSer{
274 payload.sender_ranks, SerializeHelper(dev_sched, std::move(buf))});
275 }
276 }
277
278 {
279 NamedStackEntry stack_loc2{"split recv comms"};
280 // deserialize into the shared distributed data
281 for (RecvPayloadSer &recv : recv_payload_bufs) {
282 u64 cnt_obj;
283 recv.ser.load(cnt_obj);
// NOTE(review): the counter is u32 while cnt_obj is u64; a count above the u32 range
// would make the counter wrap before reaching it.
284 for (u32 i = 0; i < cnt_obj; i++) {
285 u64 sender, receiver, length;
286
// fields are read back in the order they were written: sender, receiver, length
287 recv.ser.load(sender);
288 recv.ser.load(receiver);
289 recv.ser.load(length);
290
291 { // check correctness ranks
292 i32 supposed_sender_rank = rank_getter(sender);
293 i32 real_sender_rank = recv.sender_ranks;
294 if (supposed_sender_rank != real_sender_rank) {
295 throw make_except_with_loc<std::runtime_error>(
296 "the rank do not matches");
297 }
298 }
299
// create the destination buffer first, then load the payload into it
300 auto it = recv_distrib_data.add_obj(
301 sender, receiver, sham::DeviceBuffer<u8>(length, dev_sched));
302
303 recv.ser.load_buf(it->second, length);
304 }
305 }
306 }
307 }
308
// Sparse exchange of distributed data going through a DDSCommCache.
//
// When the "old" mode was selected this forwards to distributed_data_sparse_comm_old
// (logging a warning once, on world rank 0). Otherwise every buffer of
// send_distrib_data is grouped by (sender rank, receiver rank), split into messages
// by details::serialize_group_data_max_size, copied into the send cache, exchanged
// with sparse_exchange, read back from the receive cache and split into individual
// objects added to recv_distrib_data.
//
// dev_sched         : device scheduler used for serializers and buffers
// send_distrib_data : buffers to send, visited as (sender, receiver, buffer)
// recv_distrib_data : destination of the received (sender, receiver, buffer) objects
// rank_getter       : maps a sender/receiver id to a rank
// cache             : send/receive cache, resized here from the exchange table
// comm_table        : only forwarded to the old implementation; not read on the
//                     cached path visible in this listing
// max_comm_size     : upper bound on the size of one serialized message
//
// Throws std::runtime_error (via make_except_with_loc) when the exchange table does
// not hold one entry per prepared message, or when rank_getter applied to a received
// object's sender does not match the rank the message came from.
309 void distributed_data_sparse_comm(
310 sham::DeviceScheduler_ptr dev_sched,
311 SerializedDDataComm &send_distrib_data,
312 SerializedDDataComm &recv_distrib_data,
313 std::function<i32(u64)> rank_getter,
314 DDSCommCache &cache,
315 std::optional<SparseCommTable> comm_table,
316 size_t max_comm_size) {
317
318 if (use_old_sparse_comm_mode) {
319 if (shamcomm::world_rank() == 0 && !warning_printed) {
320 logger::warn_ln("SparseComm", "using old sparse communication mode");
321 warning_printed = true;
322 }
323 return distributed_data_sparse_comm_old(
324 dev_sched, send_distrib_data, recv_distrib_data, rank_getter, comm_table);
325 }
326
// NOTE(review): listing line 327 is missing here (not reconstructed).
328
329 using namespace shambase;
330 using DataTmp = details::DataTmp;
331
// The allocation limit is the device one when MPI can address device memory
// directly, the host one otherwise.
332 size_t max_alloc_size;
333 if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) {
334 max_alloc_size = dev_sched->ctx->device->prop.max_mem_alloc_size_dev;
335 } else {
336 max_alloc_size = dev_sched->ctx->device->prop.max_mem_alloc_size_host;
337 }
338 max_alloc_size -= 1; // keep a bit of space for safety
339
340 if (max_alloc_size > max_comm_size) {
341 max_alloc_size = max_comm_size;
342 }
343
344 // prepare map
345 std::map<std::pair<i32, i32>, std::vector<DataTmp>> send_data;
346 send_distrib_data.for_each([&](u64 sender, u64 receiver, sham::DeviceBuffer<u8> &buf) {
347 std::pair<i32, i32> key = {rank_getter(sender), rank_getter(receiver)};
348
349 send_data[key].push_back(DataTmp{sender, receiver, buf.get_size(), buf});
350 });
351
// NOTE(review): messages are split against max_comm_size, not the clamped
// max_alloc_size computed above, so a message may exceed max_alloc_size when
// max_comm_size is the larger of the two -- confirm this is intended.
352 std::vector<details::PrepareCommUtil> prepared_comms
353 = details::serialize_group_data_max_size(dev_sched, send_data, max_comm_size);
354
355 std::vector<shamalgs::collective::CommMessageInfo> messages_send;
356 std::vector<std::unique_ptr<sham::DeviceBuffer<u8>>> data_send;
357
// messages_send[i] describes data_send[i]; both are filled in the same order
358 for (auto &cms : prepared_comms) {
359
360 auto sender = cms.sender_rank;
361 auto receiver = cms.receiver_rank;
362 auto size = shambase::get_check_ref(cms.send_buf).get_size();
363
364 messages_send.push_back(
// NOTE(review): listing line 365 is missing here; it presumably opens the brace
// initialiser closed on line 372 -- confirm against the real source.
366 size,
367 sender,
368 receiver,
369 std::nullopt,
370 std::nullopt,
371 std::nullopt,
372 });
373
374 data_send.push_back(std::move(cms.send_buf));
375 }
376
// NOTE(review): listing line 377 is missing here; it presumably declares the
// `comm_table2` variable used below -- confirm against the real source.
378 = shamalgs::collective::build_sparse_exchange_table(messages_send, max_alloc_size);
379
380 if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) {
381 cache.set_sizes<sham::device>(
382 dev_sched, comm_table2.send_total_sizes, comm_table2.recv_total_sizes);
383 } else {
384 cache.set_sizes<sham::host>(
385 dev_sched, comm_table2.send_total_sizes, comm_table2.recv_total_sizes);
386 }
387
// Sanity check: the exchange table must hold exactly one entry per send buffer.
388 if (comm_table2.messages_send.size() != data_send.size()) {
389 std::vector<size_t> tmp1{};
// Iterate over the table itself: the two containers differ in size in this branch,
// so bounding this loop by data_send.size() could index past the end of
// comm_table2.messages_send.
390 for (size_t i = 0; i < comm_table2.messages_send.size(); i++) {
391 tmp1.push_back(comm_table2.messages_send[i].message_size);
392 }
393
394 std::vector<size_t> tmp2{};
395 for (size_t i = 0; i < data_send.size(); i++) {
396 tmp2.push_back(data_send[i]->get_size());
397 }
398
399 throw make_except_with_loc<std::runtime_error>(
400 shambase::format("message send mismatch : {} != {}", tmp1, tmp2));
401 }
402
// Sanity check: the exchange table must hold exactly one entry per input message.
403 if (comm_table2.messages_send.size() != messages_send.size()) {
404 std::vector<size_t> tmp1{};
405 for (size_t i = 0; i < comm_table2.messages_send.size(); i++) {
406 tmp1.push_back(comm_table2.messages_send[i].message_size);
407 }
408
409 std::vector<size_t> tmp2{};
410 for (size_t i = 0; i < messages_send.size(); i++) {
411 tmp2.push_back(messages_send[i].message_size);
412 }
413 throw make_except_with_loc<std::runtime_error>(
414 shambase::format("message send mismatch : {} != {}", tmp1, tmp2));
415 }
416
// Copy every serialized message into the send cache at the offset chosen by the table.
417 for (size_t i = 0; i < comm_table2.messages_send.size(); i++) {
418 auto &msg_info = comm_table2.messages_send[i];
419 auto offset_info = shambase::get_check_ref(msg_info.message_bytebuf_offset_send);
420 auto &buf_src = shambase::get_check_ref(data_send.at(i));
421
422 SHAM_ASSERT(buf_src.get_size() == msg_info.message_size);
423
424 cache.send_cache_write_buf_at(offset_info.buf_id, offset_info.data_offset, buf_src);
425 }
426
427 if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) {
428 shamalgs::collective::sparse_exchange<sham::device>(
429 dev_sched,
430 cache.get_cache1<sham::device>(),
431 cache.get_cache2<sham::device>(),
432 comm_table2);
433 } else {
434 shamalgs::collective::sparse_exchange<sham::host>(
435 dev_sched,
436 cache.get_cache1<sham::host>(),
437 cache.get_cache2<sham::host>(),
438 comm_table2);
439 }
440
441 // make serializers from recv buffs
442 struct RecvPayloadSer {
443 i32 sender_ranks;
444 SerializeHelper ser;
445 };
446
447 std::vector<RecvPayloadSer> recv_payload_bufs;
448
// Read each received message out of the receive cache into its own buffer.
449 for (auto &msg : comm_table2.messages_recv) {
450
451 u64 size = msg.message_size;
452 i32 sender = msg.rank_sender;
453 i32 receiver = msg.rank_receiver;
454
455 auto offset_info = shambase::get_check_ref(msg.message_bytebuf_offset_recv);
456
457 sham::DeviceBuffer<u8> recov(size, dev_sched);
458 cache.recv_cache_read_buf_at(offset_info.buf_id, offset_info.data_offset, size, recov);
459
460 recv_payload_bufs.push_back(
461 RecvPayloadSer{sender, SerializeHelper(dev_sched, std::move(recov))});
462 }
463
464 {
465 NamedStackEntry stack_loc2{"split recv comms"};
466 // deserialize into the shared distributed data
467 for (RecvPayloadSer &recv : recv_payload_bufs) {
468 u64 cnt_obj;
469 recv.ser.load(cnt_obj);
// u64 counter so that it has the same width as the serialized object count
470 for (u64 i = 0; i < cnt_obj; i++) {
471 u64 sender, receiver, length;
472
// fields are read back in the order they were written: sender, receiver, length
473 recv.ser.load(sender);
474 recv.ser.load(receiver);
475 recv.ser.load(length);
476
477 { // check correctness ranks
478 i32 supposed_sender_rank = rank_getter(sender);
479 i32 real_sender_rank = recv.sender_ranks;
480 if (supposed_sender_rank != real_sender_rank) {
481 throw make_except_with_loc<std::runtime_error>(shambase::format(
482 "the rank do not matches {} != {}",
483 supposed_sender_rank,
484 real_sender_rank));
485 }
486 }
487
// create the destination buffer first, then load the payload into it
488 auto it = recv_distrib_data.add_obj(
489 sender, receiver, sham::DeviceBuffer<u8>(length, dev_sched));
490
491 recv.ser.load_buf(it->second, length);
492 }
493 }
494 }
495 }
496
497} // namespace shamalgs::collective
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
Definition assert.hpp:67
A buffer allocated in USM (Unified Shared Memory)
size_t get_size() const
Gets the number of elements in the buffer.
Shamrock communication buffers.
static sham::DeviceBuffer< u8 > convert_usm(CommunicationBuffer &&buf)
destroy the buffer and recover the held object
This header file contains utility functions related to exception handling in the code.
Namespace for internal details of the logs module.
@ host
Host memory.
@ device
Device memory.
namespace for basic c++ utilities
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
auto extract_pointer(std::unique_ptr< T > &o, SourceLocation loc=SourceLocation()) -> T
extract content out of unique_ptr
Definition memory.hpp:227
std::string getenv_str_default_register(const char *env_var, std::string default_val, std::string desc)
Get the content of the environment variable if it exists and register its documentation,...
Definition env.hpp:88
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
This file contains the definition for the stacktrace related functionality.
#define __shamrock_stack_entry()
Macro to create a stack entry.
std::vector< size_t > recv_total_sizes
Total size of the recv buffer.
std::vector< size_t > send_total_sizes
Total size of the send buffer.
std::vector< CommMessageInfo > messages_send
Messages to send.