33namespace shamalgs::collective {
37 template<sham::USMKindTarget target>
39 std::vector<std::unique_ptr<sham::DeviceBuffer<u8, target>>> cache1;
40 std::vector<std::unique_ptr<sham::DeviceBuffer<u8, target>>> cache2;
43 sham::DeviceScheduler_ptr dev_sched,
44 const std::vector<size_t> &sizes_cache1,
45 const std::vector<size_t> &sizes_cache2) {
50 cache1.resize(sizes_cache1.size());
51 cache2.resize(sizes_cache2.size());
54 for (
size_t i = 0; i < sizes_cache1.size(); i++) {
56 cache1[i]->resize(sizes_cache1[i],
false);
58 cache1[i] = std::make_unique<sham::DeviceBuffer<u8, target>>(
59 sizes_cache1[i], dev_sched);
62 for (
size_t i = 0; i < sizes_cache2.size(); i++) {
64 cache2[i]->resize(sizes_cache2[i],
false);
66 cache2[i] = std::make_unique<sham::DeviceBuffer<u8, target>>(
67 sizes_cache2[i], dev_sched);
72 inline void send_cache_write_buf_at(
78 inline void send_cache_read_buf_at(
84 inline void recv_cache_write_buf_at(
90 inline void recv_cache_read_buf_at(
100 template<sham::USMKindTarget target>
101 std::vector<std::unique_ptr<sham::DeviceBuffer<u8, target>>> &get_cache1() {
105 template<sham::USMKindTarget target>
106 std::vector<std::unique_ptr<sham::DeviceBuffer<u8, target>>> &get_cache2() {
110 template<sham::USMKindTarget target>
112 sham::DeviceScheduler_ptr dev_sched,
113 const std::vector<size_t> &sizes_cache1,
114 const std::vector<size_t> &sizes_cache2) {
123 std::get<DDSCommCacheTarget<target>>(cache).set_sizes(
124 dev_sched, sizes_cache1, sizes_cache2);
127 inline void send_cache_write_buf_at(
131 cache.send_cache_write_buf_at(buf_id, offset, buf);
136 inline void send_cache_read_buf_at(
140 cache.send_cache_read_buf_at(buf_id, offset, size, buf);
145 inline void recv_cache_write_buf_at(
149 cache.recv_cache_write_buf_at(buf_id, offset, buf);
154 inline void recv_cache_read_buf_at(
158 cache.recv_cache_read_buf_at(buf_id, offset, size, buf);
164 void distributed_data_sparse_comm(
165 std::shared_ptr<sham::DeviceScheduler> dev_sched,
168 std::function<
i32(
u64)> rank_getter,
170 std::optional<SparseCommTable> comm_table = {},
171 size_t max_comm_size =
i32_max - 1);
174 inline void serialize_sparse_comm(
175 std::shared_ptr<sham::DeviceScheduler> dev_sched,
178 std::function<
i32(
u64)> rank_getter,
182 std::optional<SparseCommTable> comm_table = {}) {
190 return rank_getter(l) == rank_getter(r);
194 SerializedDDataComm dcomm_send
195 = send_distrib_data.template map<sham::DeviceBuffer<u8>>([&](
u64,
u64, T &obj) {
196 return serialize(obj);
199 SerializedDDataComm dcomm_recv;
201 distributed_data_sparse_comm(dev_sched, dcomm_send, dcomm_recv, rank_getter, cache);
205 return deserialize(std::move(buf));
209 "SparseComm",
"skipped", same_rank_tmp.
get_native().size(),
"communications");
227 template<
class T,
class P>
230 std::vector<P> local_ids,
231 std::vector<P> global_ids,
232 std::function<
u64(P)> id_getter) {
233 std::vector<T> vec_local(local_ids.size());
234 for (
u32 i = 0; i < local_ids.size(); i++) {
235 vec_local[i] = src.
get(id_getter(local_ids[i]));
238 std::vector<T> vec_global;
240 vec_local, get_mpi_type<T>(), vec_global, get_mpi_type<T>(), MPI_COMM_WORLD);
243 for (
u32 i = 0; i < global_ids.size(); i++) {
244 ret.
add_obj(id_getter(global_ids[i]), T(vec_global[i]));
258 template<
class T,
class P>
261 std::vector<P> local_ids,
262 std::vector<P> global_ids,
263 std::function<
u64(P)> id_getter) {
265 using Trepr =
typename T::Tload_store_repr;
266 constexpr u32 reprsz = T::sz_load_store_repr;
268 std::vector<T> vec_local(local_ids.size() * reprsz);
269 for (
u32 i = 0; i < local_ids.size(); i++) {
270 src.
get(id_getter(local_ids[i])).store(i * reprsz, vec_local);
273 std::vector<T> vec_global;
275 vec_local, get_mpi_type<T>(), vec_global, get_mpi_type<T>(), MPI_COMM_WORLD);
278 for (
u32 i = 0; i < global_ids.size(); i++) {
279 T tmp = T::load(i * reprsz, vec_global);
280 ret.
add_obj(id_getter(global_ids[i]), std::move(tmp));
Container for objects shared between two distributed data patches.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int32_t i32
32 bit integer
A buffer allocated in USM (Unified Shared Memory)
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
void copy_range_offset(size_t begin, size_t end, sham::DeviceBuffer< T, dest_target > &dest, size_t dest_offset) const
Copy a range of elements from the buffer to another buffer.
size_t get_size() const
Gets the number of elements in the buffer.
Container for objects shared between two distributed data elements.
DistributedDataShared< Tmap > map(std::function< Tmap(u64, u64, T &)> map_func)
Transform all objects to a new type using a mapping function.
std::multimap< std::pair< u64, u64 >, T > & get_native()
Get direct access to the underlying multimap container.
void tranfer_all(std::function< bool(u64, u64)> cd, DistributedDataShared &other)
Transfer objects to another container based on a condition.
Represents a collection of objects distributed across patches identified by a u64 id.
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
T & get(u64 id)
Returns a reference to an object in the collection.
shambase::DistributedData< T > fetch_all_storeload(shambase::DistributedData< T > &src, std::vector< P > local_ids, std::vector< P > global_ids, std::function< u64(P)> id_getter)
global ids = allgatherv(local_ids)
shambase::DistributedData< T > fetch_all_simple(shambase::DistributedData< T > &src, std::vector< P > local_ids, std::vector< P > global_ids, std::function< u64(P)> id_getter)
global ids = allgatherv(local_ids)
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on vector with size query (size querying variant of vector_allgatherv_ks) //TODO add fault...
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error...
constexpr i32 i32_max
i32 max value
This file contains the definition for the stacktrace related functionality.
#define __shamrock_stack_entry()
Macro to create a stack entry.