Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
distributedDataComm.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
26#include "shamcomm/logs.hpp"
27#include <functional>
28#include <mpi.h>
29#include <optional>
30#include <stdexcept>
31#include <vector>
32
33namespace shamalgs::collective {
34
36
    /// Staging buffers for distributed-data sparse communication, specialised
    /// on the USM memory target (device or host).
    template<sham::USMKindTarget target>
        // NOTE(review): the `struct DDSCommCacheTarget {` declaration line was
        // elided by the documentation extraction; the members below belong to it.
        /// Send-side staging buffers, one optional buffer per slot index.
        std::vector<std::unique_ptr<sham::DeviceBuffer<u8, target>>> cache1;
        /// Recv-side staging buffers, one optional buffer per slot index.
        std::vector<std::unique_ptr<sham::DeviceBuffer<u8, target>>> cache2;

        /// Resize both cache lists to the requested number of buffers and give
        /// each buffer the requested byte size.
        ///
        /// Existing buffers are resized in place with keep_data = false (their
        /// previous content is discarded); missing entries are freshly
        /// allocated on the given scheduler.
        ///
        /// @param dev_sched scheduler used to allocate any new buffer
        /// @param sizes_cache1 per-slot byte sizes for the send-side cache
        /// @param sizes_cache2 per-slot byte sizes for the recv-side cache
        void set_sizes(
            sham::DeviceScheduler_ptr dev_sched,
            const std::vector<size_t> &sizes_cache1,
            const std::vector<size_t> &sizes_cache2) {

            // NOTE(review): one statement here (original line 47) was elided by
            // the documentation extraction — presumably a stack-trace entry.

            // ensure correct length
            cache1.resize(sizes_cache1.size());
            cache2.resize(sizes_cache2.size());

            // if size is different, resize
            for (size_t i = 0; i < sizes_cache1.size(); i++) {
                if (cache1[i]) {
                    // keep_data = false: old content is not needed across exchanges
                    cache1[i]->resize(sizes_cache1[i], false);
                } else {
                    cache1[i] = std::make_unique<sham::DeviceBuffer<u8, target>>(
                        sizes_cache1[i], dev_sched);
                }
            }
            for (size_t i = 0; i < sizes_cache2.size(); i++) {
                if (cache2[i]) {
                    cache2[i]->resize(sizes_cache2[i], false);
                } else {
                    cache2[i] = std::make_unique<sham::DeviceBuffer<u8, target>>(
                        sizes_cache2[i], dev_sched);
                }
            }
        }
71
        /// Copy the whole content of `buf` into send-side cache buffer `buf_id`,
        /// starting at byte `offset` of that cache buffer.
        /// NOTE(review): the statement's opening line was elided by the doc
        /// extraction; from the argument list (begin, end, dest, dest_offset) it
        /// is presumably `buf.copy_range_offset(` — confirm against the original.
        inline void send_cache_write_buf_at(
            size_t buf_id, size_t offset, const sham::DeviceBuffer<u8> &buf) {
                0, buf.get_size(), shambase::get_check_ref(cache1[buf_id]), offset);
        }
77
78 inline void send_cache_read_buf_at(
79 size_t buf_id, size_t offset, size_t size, sham::DeviceBuffer<u8> &buf) {
80 buf.resize(size);
81 shambase::get_check_ref(cache1[buf_id]).copy_range(offset, offset + size, buf);
82 }
83
        /// Copy the whole content of `buf` into recv-side cache buffer `buf_id`,
        /// starting at byte `offset` of that cache buffer.
        /// NOTE(review): the statement's opening line was elided by the doc
        /// extraction; from the argument list (begin, end, dest, dest_offset) it
        /// is presumably `buf.copy_range_offset(` — confirm against the original.
        inline void recv_cache_write_buf_at(
            size_t buf_id, size_t offset, const sham::DeviceBuffer<u8> &buf) {
                0, buf.get_size(), shambase::get_check_ref(cache2[buf_id]), offset);
        }
89
90 inline void recv_cache_read_buf_at(
91 size_t buf_id, size_t offset, size_t size, sham::DeviceBuffer<u8> &buf) {
92 buf.resize(size);
93 shambase::get_check_ref(cache2[buf_id]).copy_range(offset, offset + size, buf);
94 }
95 };
96
    /// Type-erased communication cache: holds the staging buffers either in
    /// device or in host USM memory and forwards each operation to the
    /// currently active DDSCommCacheTarget alternative.
    struct DDSCommCache {
        /// Active storage; exactly one of the two target specialisations is held.
        std::variant<DDSCommCacheTarget<sham::device>, DDSCommCacheTarget<sham::host>> cache;
99
100 template<sham::USMKindTarget target>
101 std::vector<std::unique_ptr<sham::DeviceBuffer<u8, target>>> &get_cache1() {
102 return shambase::get_check_ref(std::get_if<DDSCommCacheTarget<target>>(&cache)).cache1;
103 }
104
105 template<sham::USMKindTarget target>
106 std::vector<std::unique_ptr<sham::DeviceBuffer<u8, target>>> &get_cache2() {
107 return shambase::get_check_ref(std::get_if<DDSCommCacheTarget<target>>(&cache)).cache2;
108 }
109
        /// Ensure the cache for `target` exists, then size its buffers.
        ///
        /// If the variant does not currently hold DDSCommCacheTarget<target> it
        /// is initialised first; the per-target set_sizes is then forwarded.
        ///
        /// @param dev_sched scheduler used to allocate any new buffer
        /// @param sizes_cache1 per-slot byte sizes for the send-side cache
        /// @param sizes_cache2 per-slot byte sizes for the recv-side cache
        template<sham::USMKindTarget target>
        void set_sizes(
            sham::DeviceScheduler_ptr dev_sched,
            const std::vector<size_t> &sizes_cache1,
            const std::vector<size_t> &sizes_cache2) {

            // NOTE(review): a statement before the if (original line 116) was
            // elided by the documentation extraction.

            // init if not there
            if (std::get_if<DDSCommCacheTarget<target>>(&cache) == nullptr) {
                // NOTE(review): body elided by the extraction — presumably
                // assigns DDSCommCacheTarget<target>{} to `cache`; confirm
                // against the original header.
            }

            std::get<DDSCommCacheTarget<target>>(cache).set_sizes(
                dev_sched, sizes_cache1, sizes_cache2);
        }
126
127 inline void send_cache_write_buf_at(
128 size_t buf_id, size_t offset, const sham::DeviceBuffer<u8> &buf) {
129 std::visit(
130 [&](auto &cache) {
131 cache.send_cache_write_buf_at(buf_id, offset, buf);
132 },
133 cache);
134 }
135
136 inline void send_cache_read_buf_at(
137 size_t buf_id, size_t offset, size_t size, sham::DeviceBuffer<u8> &buf) {
138 std::visit(
139 [&](auto &cache) {
140 cache.send_cache_read_buf_at(buf_id, offset, size, buf);
141 },
142 cache);
143 }
144
145 inline void recv_cache_write_buf_at(
146 size_t buf_id, size_t offset, const sham::DeviceBuffer<u8> &buf) {
147 std::visit(
148 [&](auto &cache) {
149 cache.recv_cache_write_buf_at(buf_id, offset, buf);
150 },
151 cache);
152 }
153
154 inline void recv_cache_read_buf_at(
155 size_t buf_id, size_t offset, size_t size, sham::DeviceBuffer<u8> &buf) {
156 std::visit(
157 [&](auto &cache) {
158 cache.recv_cache_read_buf_at(buf_id, offset, size, buf);
159 },
160 cache);
161 }
162 };
163
    /// Sparse exchange of serialized distributed data between MPI ranks.
    ///
    /// @param dev_sched scheduler used for buffer allocation during the exchange
    /// @param send_ddistrib_data serialized payloads to send, keyed by
    ///        (left id, right id) pairs; the destination rank of each entry is
    ///        derived through `rank_getter`
    /// @param recv_distrib_data output container filled with received payloads
    /// @param rank_getter maps a u64 object id to the MPI rank owning it
    /// @param cache reusable staging buffers for the exchange
    /// @param comm_table optional precomputed communication table
    /// @param max_comm_size maximum size of a single message; defaults to
    ///        i32_max - 1 since MPI expresses message counts as i32
    void distributed_data_sparse_comm(
        std::shared_ptr<sham::DeviceScheduler> dev_sched,
        SerializedDDataComm &send_ddistrib_data,
        SerializedDDataComm &recv_distrib_data,
        std::function<i32(u64)> rank_getter,
        DDSCommCache &cache,
        std::optional<SparseCommTable> comm_table = {},
        size_t max_comm_size = i32_max - 1); // MPI msg size limit
172
    /// Serialize, exchange, and deserialize a distributed data set in one step.
    ///
    /// Entries whose two endpoint ids resolve to the same rank skip the MPI path
    /// entirely: they are moved aside before serialization and moved straight
    /// into `recv_distrib_data` at the end.
    ///
    /// @tparam T payload type
    /// @param dev_sched scheduler used for buffer allocation
    /// @param send_distrib_data data to send (consumed; taken by rvalue ref)
    /// @param recv_distrib_data output container for received and kept objects
    /// @param rank_getter maps a u64 object id to its owning MPI rank
    /// @param serialize converts one T into a byte buffer
    /// @param deserialize rebuilds a T from a received byte buffer
    /// @param cache reusable staging buffers for the exchange
    /// @param comm_table optional precomputed communication table
    ///        NOTE(review): accepted but not forwarded to
    ///        distributed_data_sparse_comm below — confirm whether intentional.
    template<class T>
    inline void serialize_sparse_comm(
        std::shared_ptr<sham::DeviceScheduler> dev_sched,
        shambase::DistributedDataShared<T> &&send_distrib_data,
        shambase::DistributedDataShared<T> &recv_distrib_data,
        std::function<i32(u64)> rank_getter,
        std::function<sham::DeviceBuffer<u8>(T &)> serialize,
        std::function<T(sham::DeviceBuffer<u8> &&)> deserialize,
        DDSCommCache &cache,
        std::optional<SparseCommTable> comm_table = {}) {

        StackEntry stack_loc{};

        // NOTE(review): the declaration of `same_rank_tmp` (a
        // shambase::DistributedDataShared<T>) was elided by the doc extraction.
        // allow move op for same rank
        send_distrib_data.tranfer_all(
            [&](u64 l, u64 r) {
                return rank_getter(l) == rank_getter(r);
            },
            same_rank_tmp);

        // serialize every remaining entry into a byte buffer
        SerializedDDataComm dcomm_send
            = send_distrib_data.template map<sham::DeviceBuffer<u8>>([&](u64, u64, T &obj) {
                return serialize(obj);
            });

        SerializedDDataComm dcomm_recv;

        distributed_data_sparse_comm(dev_sched, dcomm_send, dcomm_recv, rank_getter, cache);

        recv_distrib_data = dcomm_recv.map<T>([&](u64, u64, sham::DeviceBuffer<u8> &buf) {
            // exchange the buffer held by the distrib data and give it to the deserializer
            return deserialize(std::move(buf));
        });

        shamlog_debug_ln(
            "SparseComm", "skipped", same_rank_tmp.get_native().size(), "communications");

        // move the same-rank entries (kept aside above) into the output
        same_rank_tmp.tranfer_all(
            [&](u64 l, u64 r) {
                return true;
            },
            recv_distrib_data);
    }
217
    /// Gather, on every rank, a copy of the objects selected by `local_ids` from
    /// all ranks (simple element-wise variant: T must be directly MPI-mappable
    /// via get_mpi_type<T>()).
    ///
    /// Precondition: `global_ids` must be the allgatherv concatenation of every
    /// rank's `local_ids`, in rank order, so indices into the gathered vector
    /// line up with `global_ids`.
    ///
    /// NOTE(review): the function signature (original lines 228-229, taking
    /// `shambase::DistributedData<T> &src` first), the `vector_allgatherv(` call
    /// line, and the declaration of `ret` were elided by the doc extraction.
    template<class T, class P>
        std::vector<P> local_ids,
        std::vector<P> global_ids,
        std::function<u64(P)> id_getter) {
        // collect the locally owned objects in local_ids order
        std::vector<T> vec_local(local_ids.size());
        for (u32 i = 0; i < local_ids.size(); i++) {
            vec_local[i] = src.get(id_getter(local_ids[i]));
        }

        std::vector<T> vec_global;
            vec_local, get_mpi_type<T>(), vec_global, get_mpi_type<T>(), MPI_COMM_WORLD);

        // rebuild a DistributedData keyed by the global ids
        for (u32 i = 0; i < global_ids.size(); i++) {
            ret.add_obj(id_getter(global_ids[i]), T(vec_global[i]));
        }
        return ret;
    }
248
    /// Gather, on every rank, a copy of the objects selected by `local_ids` from
    /// all ranks (store/load variant: each T is flattened into `sz_load_store_repr`
    /// elements of its `Tload_store_repr` representation before the exchange).
    ///
    /// Precondition: `global_ids` must be the allgatherv concatenation of every
    /// rank's `local_ids`, in rank order.
    ///
    /// NOTE(review): the function signature (original lines 259-260), the
    /// `vector_allgatherv(` call line, and the declaration of `ret` were elided
    /// by the doc extraction. Also note the element type of vec_local/vec_global
    /// is `T` while store()/load() work on the Trepr representation — looks like
    /// it should be `std::vector<Trepr>` with `get_mpi_type<Trepr>()`; confirm
    /// against the original header.
    template<class T, class P>
        std::vector<P> local_ids,
        std::vector<P> global_ids,
        std::function<u64(P)> id_getter) {

        using Trepr = typename T::Tload_store_repr;
        constexpr u32 reprsz = T::sz_load_store_repr;

        // flatten each local object into reprsz consecutive slots
        std::vector<T> vec_local(local_ids.size() * reprsz);
        for (u32 i = 0; i < local_ids.size(); i++) {
            src.get(id_getter(local_ids[i])).store(i * reprsz, vec_local);
        }

        std::vector<T> vec_global;
            vec_local, get_mpi_type<T>(), vec_global, get_mpi_type<T>(), MPI_COMM_WORLD);

        // rebuild each object from its flattened slots, keyed by global id
        for (u32 i = 0; i < global_ids.size(); i++) {
            T tmp = T::load(i * reprsz, vec_global);
            ret.add_obj(id_getter(global_ids[i]), std::move(tmp));
        }
        return ret;
    }
284
285} // namespace shamalgs::collective
Container for objects shared between two distributed data patches.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
A buffer allocated in USM (Unified Shared Memory)
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
void copy_range_offset(size_t begin, size_t end, sham::DeviceBuffer< T, dest_target > &dest, size_t dest_offset) const
Copy a range of elements from the buffer to another buffer.
size_t get_size() const
Gets the number of elements in the buffer.
Container for objects shared between two distributed data elements.
DistributedDataShared< Tmap > map(std::function< Tmap(u64, u64, T &)> map_func)
Transform all objects to a new type using a mapping function.
std::multimap< std::pair< u64, u64 >, T > & get_native()
Get direct access to the underlying multimap container.
void tranfer_all(std::function< bool(u64, u64)> cd, DistributedDataShared &other)
Transfer objects to another container based on a condition.
Represents a collection of objects distributed across patches identified by a u64 id.
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
T & get(u64 id)
Returns a reference to an object in the collection.
shambase::DistributedData< T > fetch_all_storeload(shambase::DistributedData< T > &src, std::vector< P > local_ids, std::vector< P > global_ids, std::function< u64(P)> id_getter)
global ids = allgatherv(local_ids)
shambase::DistributedData< T > fetch_all_simple(shambase::DistributedData< T > &src, std::vector< P > local_ids, std::vector< P > global_ids, std::function< u64(P)> id_getter)
global ids = allgatherv(local_ids)
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on vector with size query (size querying variant of vector_allgatherv_ks) //TODO add fault...
Definition exchanges.hpp:98
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
constexpr i32 i32_max
i32 max value
This file contains the definition for the stacktrace related functionality.
#define __shamrock_stack_entry()
Macro to create a stack entry.