30 "SPARSE_COMM_MODE",
"new",
"Sparse communication mode (new=with cache, old=without cache)");
// Scope holder for the enumeration of sparse communication modes.
33 struct SparseCommMode {
// NEW / OLD are the values the mode parser returns for the strings "new" / "old".
34 enum Mode { NEW, OLD };
// Translates the SPARSE_COMM_MODE string into a SparseCommMode value.
// Returns SparseCommMode::NEW for "new" and SparseCommMode::OLD for "old";
// any other value is rejected with std::invalid_argument.
37 constexpr auto parse_sparse_comm_mode = []() {
38 if (SPARSE_COMM_MODE ==
"new") {
39 return SparseCommMode::NEW;
40 }
else if (SPARSE_COMM_MODE ==
"old") {
41 return SparseCommMode::OLD;
// Neither "new" nor "old": report the accepted values to the user.
43 throw std::invalid_argument(
44 "Invalid sparse communication mode, valid modes are: new, old");
// True when the parsed SPARSE_COMM_MODE is SparseCommMode::OLD; computed when this variable is initialized.
48 bool use_old_sparse_comm_mode = parse_sparse_comm_mode() == SparseCommMode::OLD;
// Flag set to true alongside the "using old sparse communication mode" warning in
// distributed_data_sparse_comm. NOTE(review): presumably used to emit that warning only once — confirm.
50 bool warning_printed =
false;
53namespace shamalgs::collective {
// Serialized footprint of one entry: three u64 fields plus `length` elements of u8 payload.
// NOTE(review): the three u64 presumably are sender, receiver and length, matching the three
// values the receive loops load before the payload — confirm.
63 return SerializeHelper::serialize_byte_size<u64>() * 3
64 + SerializeHelper::serialize_byte_size<u8>(length);
// Builds one serializer per key of `send_data` and fills it with that key's entries.
// @param dev_sched  device scheduler handed to each serializer that is created
// @param send_data  entries grouped by a pair of i32 (callers build the key as
//                   {rank of sender, rank of receiver})
// The result is stored by callers as std::map<std::pair<i32, i32>, SerializeHelper>.
68 auto serialize_group_data(
69 std::shared_ptr<sham::DeviceScheduler> dev_sched,
70 std::map<std::pair<i32, i32>, std::vector<DataTmp>> &send_data)
// Pass 1: compute the byte size needed for each key and allocate its serializer.
77 for (
auto &[key, vect] : send_data) {
// Start with one u64: it holds the entry count written in pass 2.
78 SerializeSize byte_sz = SerializeHelper::serialize_byte_size<u64>();
79 for (DataTmp &d : vect) {
80 byte_sz += d.get_ser_sz();
82 serializers.emplace(key, dev_sched);
83 serializers.at(key).allocate(byte_sz);
// Pass 2: write the entry count followed by every entry of the group.
86 for (
auto &[key, vect] : send_data) {
87 SerializeHelper &ser = serializers.at(key);
88 ser.write<
u64>(vect.size());
89 for (DataTmp &d : vect) {
91 ser.write(d.receiver);
// Payload: d.length elements taken from d.data.
93 ser.write_buf(d.data, d.length);
// Ranks of the two endpoints of this message.
102 i32 sender_rank, receiver_rank;
// Non-owning references to the entries that are packed into this message.
104 std::vector<std::reference_wrapper<DataTmp>> sources;
// Serializer used to pack `sources`; created by allocate_serializer().
105 std::unique_ptr<SerializeHelper> serializer = {};
// Buffer produced by finalizing the serializer; set by finalize_serializer().
106 std::unique_ptr<sham::DeviceBuffer<u8>> send_buf = {};
// Creates the serializer on `dev_sched` and allocates it with the size stored in `sz`.
108 void allocate_serializer(std::shared_ptr<sham::DeviceScheduler> dev_sched) {
109 serializer = std::make_unique<SerializeHelper>(dev_sched);
110 serializer->allocate(sz);
// Serializes `sources`: first the number of sources as a u64, then each source,
// including its receiver and its data buffer (d.length elements).
113 void write_sources() {
115 ser.write<
u64>(sources.size());
118 ser.write(d.receiver);
120 ser.write_buf(d.data, d.length);
// Finalizes the serializer and stores the resulting byte buffer in `send_buf`.
124 void finalize_serializer() {
126 send_buf = std::make_unique<sham::DeviceBuffer<u8>>(ser.finalize());
// Packs the entries of `send_data` into messages, starting a new message for a key
// whenever adding the next entry would push the serialized size above `max_comm_size`.
// @param dev_sched      device scheduler used to allocate each message's serializer
// @param send_data      entries grouped by a pair of i32 ranks
// @param max_comm_size  upper bound on the serialized size of one message
// @return one PrepareCommUtil per produced message
130 auto serialize_group_data_max_size(
131 std::shared_ptr<sham::DeviceScheduler> dev_sched,
132 std::map<std::pair<i32, i32>, std::vector<DataTmp>> &send_data,
133 u64 max_comm_size) -> std::vector<PrepareCommUtil> {
137 std::vector<PrepareCommUtil> ret;
// Helper that closes the message currently being accumulated for `key`.
139 auto add_to_ret = [&](std::pair<i32, i32> key,
141 std::vector<std::reference_wrapper<DataTmp>> &sources) {
// Size check: a message above max_comm_size is reported as "comm size too large".
142 if (byte_sz.get_total_size() > max_comm_size) {
144 shambase::format(
"comm size too large: {}", byte_sz.get_total_size()));
147 auto [sender_rank, receiver_rank] = key;
// Only non-empty messages are recorded.
149 if (sources.size() > 0) {
150 PrepareCommUtil next{sender_rank, receiver_rank, byte_sz, sources};
151 ret.push_back(std::move(next));
// Reset the running size to the u64 header for the next message.
154 byte_sz = SerializeHelper::serialize_byte_size<u64>();
158 for (
auto &[key, vect] : send_data) {
// Each message starts with a u64 header (the entry count).
159 SerializeSize byte_sz = SerializeHelper::serialize_byte_size<u64>();
160 std::vector<std::reference_wrapper<DataTmp>> sources = {};
162 for (DataTmp &d : vect) {
163 std::reference_wrapper<DataTmp> d_ref = d;
164 auto dbyte_sz = d.get_ser_sz();
// Flush the current message before it would exceed the limit.
// NOTE(review): an entry whose own serialized size exceeds max_comm_size looks like it
// would still be appended below and then trip the size check in add_to_ret — confirm.
166 if ((dbyte_sz + byte_sz).get_total_size() > max_comm_size) {
167 add_to_ret(key, byte_sz, sources);
174 byte_sz += d.get_ser_sz();
175 sources.push_back(d_ref);
// Flush whatever remains for this key.
178 add_to_ret(key, byte_sz, sources);
// The produced messages are then processed in three separate passes over `ret`;
// the first allocates the serializers.
181 for (
auto &c : ret) {
185 c.allocate_serializer(dev_sched);
188 for (
auto &c : ret) {
// Last pass: finalize each serializer into its send buffer.
192 for (
auto &c : ret) {
193 c.finalize_serializer();
// Old-mode sparse exchange of serialized distributed data.
// @param dev_sched          device scheduler used for serialization and communication
// @param send_distrib_data  data to send
// @param recv_distrib_data  container receiving the deserialized objects
// @param rank_getter        maps a u64 id to the i32 rank owning it
// @param comm_table         optional communication table passed to sparse_comm_c
// Throws std::runtime_error when a received object's sender id does not map to the rank
// the payload actually came from.
201 void distributed_data_sparse_comm_old(
202 sham::DeviceScheduler_ptr dev_sched,
203 SerializedDDataComm &send_distrib_data,
204 SerializedDDataComm &recv_distrib_data,
205 std::function<
i32(
u64)> rank_getter,
206 std::optional<SparseCommTable> comm_table) {
211 using DataTmp = details::DataTmp;
// Group the outgoing buffers by (rank of sender, rank of receiver).
214 std::map<std::pair<i32, i32>, std::vector<DataTmp>> send_data;
216 std::pair<i32, i32> key = {rank_getter(sender), rank_getter(receiver)};
218 send_data[key].push_back(DataTmp{sender, receiver, buf.
get_size(), buf});
// Serialize each group into its own serializer.
222 std::map<std::pair<i32, i32>, SerializeHelper> serializers
223 = details::serialize_group_data(dev_sched, send_data);
// Finalize every serializer into a byte buffer.
226 std::map<std::pair<i32, i32>, std::unique_ptr<sham::DeviceBuffer<u8>>> send_bufs;
229 for (
auto &[key, ser] : serializers) {
230 send_bufs[key] = std::make_unique<sham::DeviceBuffer<u8>>(ser.finalize());
// Wrap each buffer into a send payload backed by a CommunicationBuffer.
235 std::vector<SendPayload> send_payoad;
238 for (
auto &[key, buf] : send_bufs) {
239 send_payoad.push_back(
241 std::make_unique<shamcomm::CommunicationBuffer>(
247 std::vector<RecvPayload> recv_payload;
// Exchange: sparse_comm_c uses the supplied table, base_sparse_comm does not take one.
// NOTE(review): presumably selected on whether comm_table holds a value — confirm.
250 sparse_comm_c(dev_sched, send_payoad, recv_payload, *comm_table);
252 base_sparse_comm(dev_sched, send_payoad, recv_payload);
// Pair each received payload's sender rank with a deserializer over its bytes.
256 struct RecvPayloadSer {
261 std::vector<RecvPayloadSer> recv_payload_bufs;
265 for (RecvPayload &payload : recv_payload) {
272 recv_payload_bufs.push_back(
274 payload.sender_ranks, SerializeHelper(dev_sched, std::move(buf))});
// Deserialize: object count first, then (sender, receiver, length, payload) per object.
281 for (RecvPayloadSer &recv : recv_payload_bufs) {
283 recv.ser.load(cnt_obj);
284 for (
u32 i = 0; i < cnt_obj; i++) {
285 u64 sender, receiver, length;
287 recv.ser.load(sender);
288 recv.ser.load(receiver);
289 recv.ser.load(length);
// Consistency check: the rank derived from the sender id must equal the rank the
// payload was received from.
292 i32 supposed_sender_rank = rank_getter(sender);
293 i32 real_sender_rank = recv.sender_ranks;
294 if (supposed_sender_rank != real_sender_rank) {
295 throw make_except_with_loc<std::runtime_error>(
296 "the rank do not matches");
// Register the object in the output container and load `length` elements into it.
300 auto it = recv_distrib_data.add_obj(
303 recv.ser.load_buf(it->second, length);
// Sparse exchange of serialized distributed data (entry point).
// Delegates to distributed_data_sparse_comm_old when the old mode is selected; otherwise
// splits the data into size-bounded messages and exchanges them through a cache.
// @param dev_sched          device scheduler used for serialization and communication
// @param send_distrib_data  data to send
// @param recv_distrib_data  container receiving the deserialized objects
// @param rank_getter        maps a u64 id to the i32 rank owning it
// @param comm_table         optional table, only forwarded to the old-mode implementation here
// @param max_comm_size      upper bound used when splitting messages
// Throws std::runtime_error on send-message bookkeeping mismatches and when a received
// object's sender id does not map to the rank the message came from.
309 void distributed_data_sparse_comm(
310 sham::DeviceScheduler_ptr dev_sched,
311 SerializedDDataComm &send_distrib_data,
312 SerializedDDataComm &recv_distrib_data,
313 std::function<
i32(
u64)> rank_getter,
315 std::optional<SparseCommTable> comm_table,
316 size_t max_comm_size) {
// Old mode: warn, then hand over to the legacy implementation (max_comm_size is not forwarded).
318 if (use_old_sparse_comm_mode) {
320 logger::warn_ln(
"SparseComm",
"using old sparse communication mode");
321 warning_printed =
true;
323 return distributed_data_sparse_comm_old(
324 dev_sched, send_distrib_data, recv_distrib_data, rank_getter, comm_table);
330 using DataTmp = details::DataTmp;
// Allocation limit: device limit when MPI can access device memory directly, host limit otherwise.
332 size_t max_alloc_size;
333 if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) {
334 max_alloc_size = dev_sched->ctx->device->prop.max_mem_alloc_size_dev;
336 max_alloc_size = dev_sched->ctx->device->prop.max_mem_alloc_size_host;
// Clamp the allocation limit to the requested maximum communication size.
340 if (max_alloc_size > max_comm_size) {
341 max_alloc_size = max_comm_size;
// Group the outgoing buffers by (rank of sender, rank of receiver).
345 std::map<std::pair<i32, i32>, std::vector<DataTmp>> send_data;
347 std::pair<i32, i32> key = {rank_getter(sender), rank_getter(receiver)};
349 send_data[key].push_back(DataTmp{sender, receiver, buf.
get_size(), buf});
// NOTE(review): messages are split using max_comm_size, while the exchange table below is
// built with the clamped max_alloc_size — confirm this asymmetry is intended.
352 std::vector<details::PrepareCommUtil> prepared_comms
353 = details::serialize_group_data_max_size(dev_sched, send_data, max_comm_size);
// Describe every prepared message and take ownership of its send buffer.
355 std::vector<shamalgs::collective::CommMessageInfo> messages_send;
356 std::vector<std::unique_ptr<sham::DeviceBuffer<u8>>> data_send;
358 for (
auto &cms : prepared_comms) {
360 auto sender = cms.sender_rank;
361 auto receiver = cms.receiver_rank;
364 messages_send.push_back(
374 data_send.push_back(std::move(cms.send_buf));
378 = shamalgs::collective::build_sparse_exchange_table(messages_send, max_alloc_size);
380 if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) {
// Sanity check on the send bookkeeping: on mismatch both size lists are reported.
389 std::vector<size_t> tmp1{};
390 for (
size_t i = 0; i < data_send.size(); i++) {
394 std::vector<size_t> tmp2{};
395 for (
size_t i = 0; i < data_send.size(); i++) {
396 tmp2.push_back(data_send[i]->get_size());
399 throw make_except_with_loc<std::runtime_error>(
400 shambase::format(
"message send mismatch : {} != {}", tmp1, tmp2));
// The table must list exactly as many send messages as were prepared.
403 if (comm_table2.
messages_send.size() != messages_send.size()) {
404 std::vector<size_t> tmp1{};
405 for (
size_t i = 0; i < comm_table2.
messages_send.size(); i++) {
409 std::vector<size_t> tmp2{};
410 for (
size_t i = 0; i < messages_send.size(); i++) {
411 tmp2.push_back(messages_send[i].message_size);
413 throw make_except_with_loc<std::runtime_error>(
414 shambase::format(
"message send mismatch : {} != {}", tmp1, tmp2));
// Copy every send buffer into the send cache at the offset assigned to its message.
417 for (
size_t i = 0; i < comm_table2.
messages_send.size(); i++) {
422 SHAM_ASSERT(buf_src.get_size() == msg_info.message_size);
424 cache.send_cache_write_buf_at(offset_info.buf_id, offset_info.data_offset, buf_src);
// Run the exchange on device or host memory depending on MPI capabilities.
427 if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) {
428 shamalgs::collective::sparse_exchange<sham::device>(
434 shamalgs::collective::sparse_exchange<sham::host>(
// Pair each received message's sender rank with a deserializer over its bytes.
442 struct RecvPayloadSer {
447 std::vector<RecvPayloadSer> recv_payload_bufs;
449 for (
auto &msg : comm_table2.messages_recv) {
451 u64 size = msg.message_size;
452 i32 sender = msg.rank_sender;
453 i32 receiver = msg.rank_receiver;
// Read `size` elements of this message back from the receive cache.
458 cache.recv_cache_read_buf_at(offset_info.buf_id, offset_info.data_offset, size, recov);
460 recv_payload_bufs.push_back(
461 RecvPayloadSer{sender, SerializeHelper(dev_sched, std::move(recov))});
// Deserialize: object count first, then (sender, receiver, length, payload) per object.
467 for (RecvPayloadSer &recv : recv_payload_bufs) {
469 recv.ser.load(cnt_obj);
// NOTE(review): the loop index is u32; if cnt_obj is a u64 (the count is written as u64)
// the comparison mixes widths — confirm counts stay below 2^32.
470 for (
u32 i = 0; i < cnt_obj; i++) {
471 u64 sender, receiver, length;
473 recv.ser.load(sender);
474 recv.ser.load(receiver);
475 recv.ser.load(length);
// Consistency check: the rank derived from the sender id must equal the rank the
// message was received from.
478 i32 supposed_sender_rank = rank_getter(sender);
479 i32 real_sender_rank = recv.sender_ranks;
480 if (supposed_sender_rank != real_sender_rank) {
481 throw make_except_with_loc<std::runtime_error>(shambase::format(
482 "the rank do not matches {} != {}",
483 supposed_sender_rank,
// Register the object in the output container and load `length` elements into it.
488 auto it = recv_distrib_data.add_obj(
491 recv.ser.load_buf(it->second, length);
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit signed integer
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
A buffer allocated in USM (Unified Shared Memory)
size_t get_size() const
Gets the number of elements in the buffer.
Shamrock communication buffers.
static sham::DeviceBuffer< u8 > convert_usm(CommunicationBuffer &&buf)
destroy the buffer and recover the held object
This header file contains utility functions related to exception handling in the code.
Namespace for internal details of the logs module.
namespace for basic c++ utilities
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
auto extract_pointer(std::unique_ptr< T > &o, SourceLocation loc=SourceLocation()) -> T
extract content out of unique_ptr
std::string getenv_str_default_register(const char *env_var, std::string default_val, std::string desc)
Get the content of the environment variable if it exists and register its documentation,...
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
This file contains the definition for the stacktrace related functionality.
#define __shamrock_stack_entry()
Macro to create a stack entry.
std::vector< size_t > recv_total_sizes
Total size of the recv buffer.
std::vector< size_t > send_total_sizes
Total size of the send buffer.
std::vector< CommMessageInfo > messages_send
Messages to send.