Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
exchanges.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
24#include "shamcomm/logs.hpp"
25#include "shamcomm/mpi.hpp"
28#include "shamcomm/wrapper.hpp"
29#include <numeric>
30#include <vector>
31
32namespace shamalgs::collective {
33
43 template<class T>
45 const std::vector<T> &send_vec,
46 const MPI_Datatype send_type,
47 std::vector<T> &recv_vec,
48 const MPI_Datatype recv_type,
49 const MPI_Comm comm) {
50
51 u32 local_count = send_vec.size();
52
53 int *table_data_count = new int[shamcomm::world_size()];
54
55 // crash
56 shamcomm::mpi::Allgather(&local_count, 1, MPI_INT, &table_data_count[0], 1, MPI_INT, comm);
57
58 // printf("table_data_count =
59 // [%d,%d,%d,%d]\n",table_data_count[0],table_data_count[1],table_data_count[2],table_data_count[3]);
60
61 int *node_displacements_data_table = new int[shamcomm::world_size()];
62
63 node_displacements_data_table[0] = 0;
64
65 for (u32 i = 1; i < shamcomm::world_size(); i++) {
66 node_displacements_data_table[i]
67 = node_displacements_data_table[i - 1] + table_data_count[i - 1];
68 }
69
70 // printf("node_displacements_data_table =
71 // [%d,%d,%d,%d]\n",node_displacements_data_table[0],node_displacements_data_table[1],node_displacements_data_table[2],node_displacements_data_table[3]);
72
73 shamcomm::mpi::Allgatherv(
74 &send_vec[0],
75 send_vec.size(),
76 send_type,
77 &recv_vec[0],
78 table_data_count,
79 node_displacements_data_table,
80 recv_type,
81 comm);
82
83 delete[] table_data_count;
84 delete[] node_displacements_data_table;
85 }
86
97 template<class T>
98 inline std::vector<int> vector_allgatherv(
99 const std::vector<T> &send_vec,
100 const MPI_Datatype &send_type,
101 std::vector<T> &recv_vec,
102 const MPI_Datatype &recv_type,
103 const MPI_Comm comm) {
104 StackEntry stack_loc{};
105
106 int comm_size = 0;
107
108 if (comm == MPI_COMM_WORLD) {
109 comm_size = shamcomm::world_size();
110 } else {
111 MPICHECK(MPI_Comm_size(comm, &comm_size));
112 }
113
114 int local_count = shambase::narrow_or_throw<int>(send_vec.size());
115
116 std::vector<int> table_data_count(static_cast<std::size_t>(comm_size));
117
118 shamcomm::mpi::Allgather(
119 &local_count, 1, MPI_INT, table_data_count.data(), 1, MPI_INT, comm);
120
121 int global_len = 0;
122 // use work duplication or MPI reduction
123#if false
124 // query global size and resize the receiving vector
125 shamcomm::mpi::Allreduce(
126 &local_count, &global_len, 1, MPI_INT, MPI_SUM, comm);
127#else
128 {
129 u64 tmp = std::accumulate(table_data_count.begin(), table_data_count.end(), 0_u64);
130
131 // if it exceeds the max size of int, MPI will trip like crazy
132 // god damn it just implement 64bits indices ... Pleeeeeasssssse !!!
133 global_len = shambase::narrow_or_throw<int>(tmp);
134 }
135#endif
136
137 recv_vec.resize(global_len);
138
139 if (global_len == 0) {
140 return {};
141 }
142
143 // here we can not overflow since we know that the sum can be narrowed to an int
144 std::vector<int> node_displacements_data_table(static_cast<std::size_t>(comm_size));
145 std::exclusive_scan(
146 table_data_count.begin(),
147 table_data_count.end(),
148 node_displacements_data_table.begin(),
149 0);
150
151 shamcomm::mpi::Allgatherv(
152 send_vec.data(), // even if the size is 0 MPI does not care
153 local_count,
154 send_type,
155 recv_vec.data(),
156 table_data_count.data(),
157 node_displacements_data_table.data(),
158 recv_type,
159 comm);
160
161 return node_displacements_data_table;
162 }
163
174 template<class T>
176 const std::vector<T> &send_vec,
177 const MPI_Datatype &send_type,
178 std::vector<T> &recv_vec,
179 const MPI_Datatype &recv_type,
180 const MPI_Comm comm,
181 u32 com_per_step = (1_i32 << 29) / static_cast<u32>(shamcomm::world_size())) {
182
183 // check that comm is MPI_COMM_WORLD
184 if (comm != MPI_COMM_WORLD) {
185 throw shambase::make_except_with_loc<std::runtime_error>("comm must be MPI_COMM_WORLD");
186 }
187
188 u64 send_offset = 0_u64;
189 std::vector<u64> result_disps(shamcomm::world_size() + 1, 0_u64);
190
191 while (!shamalgs::collective::are_all_rank_true(send_offset == send_vec.size(), comm)) {
192 // extract com_per_step elements from send_vec
193 u64 remaining
194 = (send_offset < send_vec.size()) ? (send_vec.size() - send_offset) : 0_u64;
195 u64 num_to_send = std::min<u64>(com_per_step, remaining);
196 std::vector<T> send_vec_internal(
197 send_vec.begin() + send_offset, send_vec.begin() + send_offset + num_to_send);
198 send_offset += num_to_send;
199
200 std::vector<T> recv_vec_internal{};
201 auto disp = vector_allgatherv(
202 send_vec_internal, send_type, recv_vec_internal, recv_type, comm);
203 disp.push_back(shambase::narrow_or_throw<int>(recv_vec_internal.size()));
204
205 // The bit that insert in such a way that it reproduce vector_allgatherv
206 for (u32 i = 0; i < (disp.size() - 1); i++) {
207 auto insert_loc = recv_vec.begin() + result_disps[i + 1] + disp[i];
208 recv_vec.insert(
209 insert_loc,
210 recv_vec_internal.begin() + disp[i],
211 recv_vec_internal.begin() + disp[i + 1]);
212 result_disps[i] += disp[i];
213 }
214 result_disps[disp.size() - 1] += disp[disp.size() - 1];
215 }
216 }
217
    /// @brief convenience overload of vector_allgatherv deducing the MPI datatype from T
    ///
    /// Forwards to the size-querying vector_allgatherv with get_mpi_type<T>() used for
    /// both the send and receive datatypes; the returned displacement table is discarded.
    ///
    /// @param send_vec local data to contribute
    /// @param recv_vec output vector, resized and filled with the gathered data
    /// @param comm the MPI communicator to use
    template<class T>
    inline void vector_allgatherv(
        const std::vector<T> &send_vec, std::vector<T> &recv_vec, const MPI_Comm comm) {
        vector_allgatherv(send_vec, get_mpi_type<T>(), recv_vec, get_mpi_type<T>(), comm);
    }
231
232} // namespace shamalgs::collective
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
Collective boolean reduction to check if all ranks have true as input.
bool are_all_rank_true(bool input, MPI_Comm comm)
return true only if all ranks have true as input
void vector_allgatherv_ks(const std::vector< T > &send_vec, const MPI_Datatype send_type, std::vector< T > &recv_vec, const MPI_Datatype recv_type, const MPI_Comm comm)
allgatherv with knowing total count of object //TODO add fault tolerance
Definition exchanges.hpp:44
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on vector with size query (size querying variant of vector_allgatherv_ks) //TODO add fault...
Definition exchanges.hpp:98
void vector_allgatherv_large(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm, u32 com_per_step=(1_i32<< 29)/static_cast< u32 >(shamcomm::world_size()))
vector_allgatherv version that support having more than 2^31 elements in flight
Utility functions for MPI error checking.
#define MPICHECK(mpicall)
Shortcut macro to check MPI return codes.
Use this header to include MPI properly.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
i32 world_size()
Gives the size of the MPI communicator.
Definition worldInfo.cpp:38
Utilities for safe type narrowing conversions.
This file contains the definition for the stacktrace related functionality.
Functions related to the MPI communicator.