32namespace shamalgs::collective {
35 inline T allreduce_one(T a, MPI_Op op, MPI_Comm comm) {
41 template<
class T,
int n>
42 inline sycl::vec<T, n> allreduce_one(sycl::vec<T, n> a, MPI_Op op, MPI_Comm comm) {
44 if constexpr (n == 2) {
47 }
else if constexpr (n == 3) {
58 inline T allreduce_sum(T a) {
59 return allreduce_one(a, MPI_SUM, MPI_COMM_WORLD);
63 inline T allreduce_min(T a) {
64 return allreduce_one(a, MPI_MIN, MPI_COMM_WORLD);
68 inline T allreduce_max(T a) {
69 return allreduce_one(a, MPI_MAX, MPI_COMM_WORLD);
73 inline std::pair<T, T> allreduce_bounds(std::pair<T, T> bounds) {
74 return {allreduce_min(bounds.first), allreduce_max(bounds.second)};
77 template<
class T, sham::USMKindTarget target>
82 reduce_buffer_in_place_sum(flat, comm);
83 field = shamalgs::primitives::unflatten_buffer<T, target>(flat);
88 "MPI message are limited to i32_max in size");
100 MPI_IN_PLACE, ptr, field.
get_size(), get_mpi_type<T>(), MPI_SUM, comm);
105 = field.template copy_to<sham::host>();
106 reduce_buffer_in_place_sum(field_host, comm);
118 MPI_IN_PLACE, ptr, field.
get_size(), get_mpi_type<T>(), MPI_SUM, comm);
128 inline std::vector<T> gather(T a, MPI_Comm comm = MPI_COMM_WORLD,
int root = 0) {
131 &a, 1, get_mpi_type<T>(), ret.data(), 1, get_mpi_type<T>(), root, comm);
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
void copy_from(const DeviceBuffer< T, new_target > &other, size_t copy_size)
Copies the content of another buffer to this one.
size_t get_size() const
Gets the number of elements in the buffer.
DeviceScheduler & get_dev_scheduler() const
Gets the Device scheduler corresponding to the held allocation.
bool use_direct_comm()
Check if the context corresponding to the device scheduler should use direct communication.
Class to manage a list of SYCL events.
void wait_and_throw()
Wait for all events in the list to be finished and throw an exception if one has occurred.
This header file contains utility functions related to exception handling in the code.
Utility functions for MPI error checking.
Use this header to include MPI properly.
sham::DeviceBuffer< typename shambase::VectorProperties< Tvec >::component_type, target > flatten_buffer(const sham::DeviceBuffer< Tvec, target > &buffer)
Flatten a buffer of vector type into a buffer of scalar type.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_size()
Gives the size of the MPI communicator.
constexpr i32 i32_max
Maximum value representable by the i32 type.
Functions related to the MPI communicator.
void Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
MPI wrapper for MPI_Allreduce.
void Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
MPI wrapper for MPI_Gather.