32namespace shamalgs::collective {
35 inline T allreduce_one(T a, MPI_Op op, MPI_Comm comm) {
41 template<
class T,
int n>
42 inline sycl::vec<T, n> allreduce_one(sycl::vec<T, n> a, MPI_Op op, MPI_Comm comm) {
44 if constexpr (n == 2) {
47 }
else if constexpr (n == 3) {
58 inline T allreduce_sum(T a) {
59 return allreduce_one(a, MPI_SUM, MPI_COMM_WORLD);
63 inline T allreduce_min(T a) {
64 return allreduce_one(a, MPI_MIN, MPI_COMM_WORLD);
68 inline T allreduce_max(T a) {
69 return allreduce_one(a, MPI_MAX, MPI_COMM_WORLD);
73 inline std::pair<T, T> allreduce_bounds(std::pair<T, T> bounds) {
74 return {allreduce_min(bounds.first), allreduce_max(bounds.second)};
77 template<
class T, sham::USMKindTarget target>
78 inline void reduce_buffer_in_place_sum(sham::DeviceBuffer<T, target> &field, MPI_Comm comm) {
80 if constexpr (shambase::VectorProperties<T>::dimension > 1) {
82 reduce_buffer_in_place_sum(flat, comm);
88 "MPI message are limited to i32_max in size");
94 sham::EventList depends_list;
100 MPI_IN_PLACE, ptr, field.
get_size(), get_mpi_type<T>(), MPI_SUM, comm);
104 sham::DeviceBuffer<T, sham::host> field_host
105 = field.template copy_to<sham::host>();
106 reduce_buffer_in_place_sum(field_host, comm);
112 sham::EventList depends_list;
118 MPI_IN_PLACE, ptr, field.
get_size(), get_mpi_type<T>(), MPI_SUM, comm);
128 inline std::vector<T> gather(T a, MPI_Comm comm = MPI_COMM_WORLD,
int root = 0) {
131 &a, 1, get_mpi_type<T>(), ret.data(), 1, get_mpi_type<T>(), root, comm);
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
void copy_from(const DeviceBuffer< T, new_target > &other, size_t copy_size)
Copies the content of another buffer to this one.
size_t get_size() const
Gets the number of elements in the buffer.
DeviceScheduler & get_dev_scheduler() const
Gets the Device scheduler corresponding to the held allocation.
bool use_direct_comm()
Check if the context corresponding to the device scheduler should use direct communication.
void wait_and_throw()
Wait for all events in the list to be finished and throw an exception if one has occurred.
This header file contains utility functions related to exception handling in the code.
Utility functions for MPI error checking.
Use this header to include MPI properly.
sham::DeviceBuffer< Tvec, target > unflatten_buffer(const sham::DeviceBuffer< typename shambase::VectorProperties< Tvec >::component_type, target > &buffer)
Unflatten a buffer that contains a flattened vector.
sham::DeviceBuffer< typename shambase::VectorProperties< Tvec >::component_type, target > flatten_buffer(const sham::DeviceBuffer< Tvec, target > &buffer)
Flatten a buffer of vector type into a buffer of scalar type.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
ExcptTypes make_except_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Create an exception with a message and a location.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_size()
Gives the size of the MPI communicator.
constexpr i32 i32_max
Maximum value representable by an i32.
Functions related to the MPI communicator.
void Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
MPI wrapper for MPI_Allreduce.
void Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
MPI wrapper for MPI_Gather.