Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
CommunicationBuffer.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
18#include "shambackends/math.hpp"
19#include "shamcomm/logs.hpp"
20#include "shamcomm/mpi.hpp"
22#include "shamcomm/wrapper.hpp"
23
24namespace shamcomm {
25
26 bool validate_comm_internal(std::shared_ptr<sham::DeviceScheduler> &device_sched) {
27
28 u32 nbytes = 1e5;
29 sham::DeviceBuffer<u8> buf_comp(nbytes, device_sched);
30
31 {
32 std::vector<u8> host_data(nbytes);
33 for (u32 i = 0; i < nbytes; i++) {
34 host_data[i] = i % 100;
35 }
36 buf_comp.copy_from_stdvec(host_data);
37 }
38
39 shamcomm::CommunicationBuffer cbuf{buf_comp, device_sched};
40 shamcomm::CommunicationBuffer cbuf_recv{nbytes, device_sched};
41
42 MPI_Request rq1, rq2;
44 shamcomm::mpi::Isend(cbuf.get_ptr(), nbytes, MPI_BYTE, 0, 0, MPI_COMM_WORLD, &rq1);
45 }
46
47 if (shamcomm::world_rank() == 0) {
49 cbuf_recv.get_ptr(),
50 nbytes,
51 MPI_BYTE,
53 0,
54 MPI_COMM_WORLD,
55 &rq2);
56 }
57
59 shamcomm::mpi::Wait(&rq1, MPI_STATUS_IGNORE);
60 }
61
62 if (shamcomm::world_rank() == 0) {
63 shamcomm::mpi::Wait(&rq2, MPI_STATUS_IGNORE);
64 }
65
67 = shamcomm::CommunicationBuffer::convert_usm(std::move(cbuf_recv));
68
69 bool valid = true;
70
71 if (shamcomm::world_rank() == 0) {
72 std::vector<u8> acc1 = buf_comp.copy_to_stdvec();
73 std::vector<u8> acc2 = recv.copy_to_stdvec();
74
75 std::string id_err_list = "errors in id : ";
76
77 bool eq = true;
78 for (u32 i = 0; i < acc1.size(); i++) {
79 if (!sham::equals(acc1[i], acc2[i])) {
80 eq = false;
81 // id_err_list += std::to_string(i) + " ";
82 }
83 }
84
85 valid = eq;
86 }
87
88 return valid;
89 }
90
91 void validate_comm(std::shared_ptr<sham::DeviceScheduler> &sched) {
92 u32 nbytes = 1e5;
93
94 bool call_abort = false;
95
96 bool dgpu_mode = sched->ctx->device->mpi_prop.is_mpi_direct_capable;
97
98 using namespace shambase::term_colors;
99 if (dgpu_mode) {
100 if (validate_comm_internal(sched)) {
101 if (shamcomm::world_rank() == 0)
102 logger::raw_ln(" - MPI use Direct Comm :", col8b_green() + "Working" + reset());
103 } else {
104 if (shamcomm::world_rank() == 0)
105 logger::raw_ln(" - MPI use Direct Comm :", col8b_red() + "Fail" + reset());
106 if (shamcomm::world_rank() == 0)
107 logger::err_ln("Sys", "the select comm mode failed, try forcing dgpu mode off");
108 call_abort = true;
109 }
110 } else {
111 if (validate_comm_internal(sched)) {
112 if (shamcomm::world_rank() == 0)
113 logger::raw_ln(
114 " - MPI use Copy to Host :", col8b_green() + "Working" + reset());
115 } else {
116 if (shamcomm::world_rank() == 0)
117 logger::raw_ln(" - MPI use Copy to Host :", col8b_red() + "Fail" + reset());
118 call_abort = true;
119 }
120 }
121
122 shamcomm::mpi::Barrier(MPI_COMM_WORLD);
123
124 if (call_abort) {
125 MPI_Abort(MPI_COMM_WORLD, 26);
126 }
127 }
128
129} // namespace shamcomm
Shamrock communication buffers.
std::uint32_t u32
32-bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
std::vector< T > copy_to_stdvec() const
Copy the content of the buffer to a std::vector.
Shamrock communication buffers.
static sham::DeviceBuffer< u8 > convert_usm(CommunicationBuffer &&buf)
Destroy the buffer and recover the object it holds.
Use this header to include MPI properly.
Namespace for communication-related utilities.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
i32 world_size()
Gives the size of the MPI communicator.
Definition worldInfo.cpp:38
const std::string reset()
Get the reset terminal escape char.
const std::string col8b_green()
Get the green terminal escape char.
const std::string col8b_red()
Get the red terminal escape char.
Functions related to the MPI communicator.
void Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Irecv.
Definition wrapper.cpp:102
void Barrier(MPI_Comm comm)
MPI wrapper for MPI_Barrier.
Definition wrapper.cpp:194
void Wait(MPI_Request *request, MPI_Status *status)
MPI wrapper for MPI_Wait.
Definition wrapper.cpp:180
void Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)
MPI wrapper for MPI_Isend.
Definition wrapper.cpp:85