Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
reduction.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
24#include "shamcomm/mpi.hpp"
27#include "shamcomm/wrapper.hpp"
28#include <type_traits>
29#include <stdexcept>
30#include <utility>
31
32namespace shamalgs::collective {
33
34 template<class T>
35 inline T allreduce_one(T a, MPI_Op op, MPI_Comm comm) {
36 T ret;
37 shamcomm::mpi::Allreduce(&a, &ret, 1, get_mpi_type<T>(), op, comm);
38 return ret;
39 }
40
41 template<class T, int n>
42 inline sycl::vec<T, n> allreduce_one(sycl::vec<T, n> a, MPI_Op op, MPI_Comm comm) {
43 sycl::vec<T, n> ret;
44 if constexpr (n == 2) {
45 shamcomm::mpi::Allreduce(&a.x(), &ret.x(), 1, get_mpi_type<T>(), op, comm);
46 shamcomm::mpi::Allreduce(&a.y(), &ret.y(), 1, get_mpi_type<T>(), op, comm);
47 } else if constexpr (n == 3) {
48 shamcomm::mpi::Allreduce(&a.x(), &ret.x(), 1, get_mpi_type<T>(), op, comm);
49 shamcomm::mpi::Allreduce(&a.y(), &ret.y(), 1, get_mpi_type<T>(), op, comm);
50 shamcomm::mpi::Allreduce(&a.z(), &ret.z(), 1, get_mpi_type<T>(), op, comm);
51 } else {
53 }
54 return ret;
55 }
56
57 template<class T>
58 inline T allreduce_sum(T a) {
59 return allreduce_one(a, MPI_SUM, MPI_COMM_WORLD);
60 }
61
62 template<class T>
63 inline T allreduce_min(T a) {
64 return allreduce_one(a, MPI_MIN, MPI_COMM_WORLD);
65 }
66
67 template<class T>
68 inline T allreduce_max(T a) {
69 return allreduce_one(a, MPI_MAX, MPI_COMM_WORLD);
70 }
71
72 template<class T>
73 inline std::pair<T, T> allreduce_bounds(std::pair<T, T> bounds) {
74 return {allreduce_min(bounds.first), allreduce_max(bounds.second)};
75 }
76
/// @brief Element-wise sum-reduction of a buffer across all ranks of `comm`, in place.
///
/// On the paths that reach MPI directly, the buffer is passed to an Allreduce with
/// MPI_IN_PLACE and MPI_SUM, so `field` ends up holding the element-wise sum of every
/// rank's `field`. The other paths (flattened, host-staged) recurse into this function.
///
/// @param field buffer to reduce; overwritten with the reduced values
/// @param comm  communicator over which the reduction runs
/// NOTE(review): the exception raised when field.get_size() > i32_max is thrown on a listing
/// line not visible here; its type is not documented.
/// NOTE(review): if the MPI wrapper throws between get_write_access and
/// complete_event_state, the write access is never released — confirm this is acceptable.
77 template<class T, sham::USMKindTarget target>
78 inline void reduce_buffer_in_place_sum(sham::DeviceBuffer<T, target> &field, MPI_Comm comm) {
79
// NOTE(review): the guard of this branch (listing line 80) is not visible. From the calls
// below it flattens the buffer (flatten_buffer: vector type -> scalar type), reduces the
// flat buffer recursively, then unflattens the result back into `field` — confirm.
81 auto flat = shamalgs::primitives::flatten_buffer(field);
82 reduce_buffer_in_place_sum(flat, comm);
83 field = shamalgs::primitives::unflatten_buffer<T, target>(flat);
84 } else {
85
// The element count is passed to MPI, whose counts are limited to i32_max.
86 if (field.get_size() > size_t(i32_max)) {
88 "MPI message are limited to i32_max in size");
89 }
90
91 if constexpr (target == sham::device) {
92
// Direct communication: the device pointer is handed straight to MPI
// (presumably this requires a device-aware MPI — confirm).
93 if (field.get_dev_scheduler().use_direct_comm()) {
94 sham::EventList depends_list;
95 T *ptr = field.get_write_access(depends_list);
96
// Wait for pending work on the buffer, then reduce it in place:
// with MPI_IN_PLACE, `ptr` is both the send and the receive buffer.
97 depends_list.wait_and_throw();
98
100 MPI_IN_PLACE, ptr, field.get_size(), get_mpi_type<T>(), MPI_SUM, comm);
101
// The Allreduce is a blocking collective, so no work is left pending on the
// buffer; the write access is released with an empty event.
102 field.complete_event_state(sycl::event{});
103 } else {
// Staged path: copy to a host buffer, reduce that copy through the host
// branch of this function, then copy the result back to the device buffer.
105 = field.template copy_to<sham::host>();
106 reduce_buffer_in_place_sum(field_host, comm);
107 field.copy_from(field_host);
108 }
109
110 } else if (target == sham::host) {
111
// Host memory: same in-place Allreduce as the direct device path.
// NOTE(review): this is a plain `if`, not `if constexpr` like the device test;
// it compiles for every target but is inconsistent with the branch it follows.
112 sham::EventList depends_list;
113 T *ptr = field.get_write_access(depends_list);
114
115 depends_list.wait_and_throw();
116
118 MPI_IN_PLACE, ptr, field.get_size(), get_mpi_type<T>(), MPI_SUM, comm);
119
120 field.complete_event_state(sycl::event{});
121 } else {
// Any other USM target is not handled (listing line 122 is not visible;
// presumably an "unimplemented" throw — confirm).
123 }
124 }
125 }
126
127 template<class T>
128 inline std::vector<T> gather(T a, MPI_Comm comm = MPI_COMM_WORLD, int root = 0) {
129 std::vector<T> ret(shamcomm::world_size());
131 &a, 1, get_mpi_type<T>(), ret.data(), 1, get_mpi_type<T>(), root, comm);
132 return ret;
133 }
134
135} // namespace shamalgs::collective
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
void copy_from(const DeviceBuffer< T, new_target > &other, size_t copy_size)
Copies the content of another buffer to this one.
size_t get_size() const
Gets the number of elements in the buffer.
DeviceScheduler & get_dev_scheduler() const
Gets the Device scheduler corresponding to the held allocation.
bool use_direct_comm()
Check if the context corresponding to the device scheduler should use direct communication.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
void wait_and_throw()
Wait for all events in the list to be finished and throw an exception if one has occurred.
Definition EventList.hpp:72
This header file contains utility functions related to exception handling in the code.
Utility functions for MPI error checking.
Use this header to include MPI properly.
@ host
Host memory.
@ device
Device memory.
sham::DeviceBuffer< typename shambase::VectorProperties< Tvec >::component_type, target > flatten_buffer(const sham::DeviceBuffer< Tvec, target > &buffer)
Flatten a buffer of vector type into a buffer of scalar type.
Definition flatten.hpp:40
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_size()
Gives the size of the MPI communicator.
Definition worldInfo.cpp:38
constexpr i32 i32_max
i32 max value
Functions related to the MPI communicator.
void Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
MPI wrapper for MPI_Allreduce.
Definition wrapper.cpp:119
void Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
MPI wrapper for MPI_Gather.
Definition wrapper.cpp:326