32namespace shamalgs::collective {
45 const std::vector<T> &send_vec,
46 const MPI_Datatype send_type,
47 std::vector<T> &recv_vec,
48 const MPI_Datatype recv_type,
49 const MPI_Comm comm) {
51 u32 local_count = send_vec.size();
56 shamcomm::mpi::Allgather(&local_count, 1, MPI_INT, &table_data_count[0], 1, MPI_INT, comm);
63 node_displacements_data_table[0] = 0;
66 node_displacements_data_table[i]
67 = node_displacements_data_table[i - 1] + table_data_count[i - 1];
73 shamcomm::mpi::Allgatherv(
79 node_displacements_data_table,
83 delete[] table_data_count;
84 delete[] node_displacements_data_table;
99 const std::vector<T> &send_vec,
100 const MPI_Datatype &send_type,
101 std::vector<T> &recv_vec,
102 const MPI_Datatype &recv_type,
103 const MPI_Comm comm) {
108 if (comm == MPI_COMM_WORLD) {
111 MPICHECK(MPI_Comm_size(comm, &comm_size));
116 std::vector<int> table_data_count(
static_cast<std::size_t
>(comm_size));
118 shamcomm::mpi::Allgather(
119 &local_count, 1, MPI_INT, table_data_count.data(), 1, MPI_INT, comm);
125 shamcomm::mpi::Allreduce(
126 &local_count, &global_len, 1, MPI_INT, MPI_SUM, comm);
129 u64 tmp = std::accumulate(table_data_count.begin(), table_data_count.end(), 0_u64);
137 recv_vec.resize(global_len);
139 if (global_len == 0) {
144 std::vector<int> node_displacements_data_table(
static_cast<std::size_t
>(comm_size));
146 table_data_count.begin(),
147 table_data_count.end(),
148 node_displacements_data_table.begin(),
151 shamcomm::mpi::Allgatherv(
156 table_data_count.data(),
157 node_displacements_data_table.data(),
161 return node_displacements_data_table;
176 const std::vector<T> &send_vec,
177 const MPI_Datatype &send_type,
178 std::vector<T> &recv_vec,
179 const MPI_Datatype &recv_type,
184 if (comm != MPI_COMM_WORLD) {
188 u64 send_offset = 0_u64;
194 = (send_offset < send_vec.size()) ? (send_vec.size() - send_offset) : 0_u64;
195 u64 num_to_send = std::min<u64>(com_per_step, remaining);
196 std::vector<T> send_vec_internal(
197 send_vec.begin() + send_offset, send_vec.begin() + send_offset + num_to_send);
198 send_offset += num_to_send;
200 std::vector<T> recv_vec_internal{};
202 send_vec_internal, send_type, recv_vec_internal, recv_type, comm);
206 for (
u32 i = 0; i < (disp.size() - 1); i++) {
207 auto insert_loc = recv_vec.begin() + result_disps[i + 1] + disp[i];
210 recv_vec_internal.begin() + disp[i],
211 recv_vec_internal.begin() + disp[i + 1]);
212 result_disps[i] += disp[i];
214 result_disps[disp.size() - 1] += disp[disp.size() - 1];
228 const std::vector<T> &send_vec, std::vector<T> &recv_vec,
const MPI_Comm comm) {
229 vector_allgatherv(send_vec, get_mpi_type<T>(), recv_vec, get_mpi_type<T>(), comm);
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
Collective boolean reduction to check if all ranks have true as input.
bool are_all_rank_true(bool input, MPI_Comm comm)
Returns true only if all ranks have true as input.
void vector_allgatherv_ks(const std::vector< T > &send_vec, const MPI_Datatype send_type, std::vector< T > &recv_vec, const MPI_Datatype recv_type, const MPI_Comm comm)
allgatherv for the case where the total object count is already known. TODO: add fault tolerance.
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on a vector with a size query (size-querying variant of vector_allgatherv_ks). TODO: add fault tolerance.
void vector_allgatherv_large(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm, u32 com_per_step=(1_i32<< 29)/static_cast< u32 >(shamcomm::world_size()))
Version of vector_allgatherv that supports more than 2^31 elements in flight.
Utility functions for MPI error checking.
#define MPICHECK(mpicall)
Shortcut macro to check MPI return codes.
Use this header to include MPI properly.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
i32 world_size()
Gives the size of the MPI communicator.
Utilities for safe type narrowing conversions.
This file contains the definition for the stacktrace related functionality.
Functions related to the MPI communicator.