Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
SyclMpiTypes.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
17#include "shamcomm/logs.hpp"
19
/// Interop flag: set to true at the end of create_sycl_mpi_types() and back to false at the
/// end of free_sycl_mpi_types().
bool __mpi_sycl_type_active = false;

/**
 * @brief Report whether the SYCL vector MPI datatypes are currently committed.
 *
 * @return The current value of the interop flag: false initially, true once
 *         create_sycl_mpi_types() has completed, false again after free_sycl_mpi_types().
 */
bool is_mpi_sycl_interop_active() {
    const bool interop_active = __mpi_sycl_type_active;
    return interop_active;
}
22
// Element counts of the SYCL vector widths that get an MPI datatype.
// The __SYCL_TYPE_COMMIT_len* macros pass these as the `count` argument of
// MPI_Type_contiguous, i.e. each vector is described as `count` consecutive scalars.
constexpr int __len_vec2  = 2;
constexpr int __len_vec3  = 3;
constexpr int __len_vec4  = 4;
constexpr int __len_vec8  = 8;
constexpr int __len_vec16 = 16;
35
36inline MPI_Datatype __tmp_mpi_type_i64_3;
37inline MPI_Datatype __tmp_mpi_type_i32_3;
38inline MPI_Datatype __tmp_mpi_type_i16_3;
39inline MPI_Datatype __tmp_mpi_type_i8_3;
40inline MPI_Datatype __tmp_mpi_type_u64_3;
41inline MPI_Datatype __tmp_mpi_type_u32_3;
42inline MPI_Datatype __tmp_mpi_type_u16_3;
43inline MPI_Datatype __tmp_mpi_type_u8_3;
44inline MPI_Datatype __tmp_mpi_type_f16_3;
45inline MPI_Datatype __tmp_mpi_type_f32_3;
46inline MPI_Datatype __tmp_mpi_type_f64_3;
47
/**
 * @brief Build, commit and log an MPI datatype made of @p count contiguous @p src_type.
 *
 * Shared implementation of the power-of-two commit macros below. At the expansion site it
 * expects the handles `mpi_type_<base_name>` and `mpi_type_<src_type>` to be in scope, and
 * `check_offset_validity<base_name>()` to be declared. The body is wrapped in
 * do/while(0) so that `MACRO(...);` is a single statement.
 */
#define __SYCL_TYPE_COMMIT_contiguous(base_name, src_type, count)                                   \
    do {                                                                                           \
        check_offset_validity<base_name>();                                                        \
        MPICHECK(MPI_Type_contiguous(count, mpi_type_##src_type, &mpi_type_##base_name));          \
        MPICHECK(MPI_Type_commit(&mpi_type_##base_name));                                          \
        shamlog_debug_mpi_ln("SyclMpiTypes", "init mpi type for : " #base_name);                   \
    } while (0)

/// Commit `mpi_type_<base_name>` as 2 contiguous `src_type` elements.
#define __SYCL_TYPE_COMMIT_len2(base_name, src_type)                                                \
    __SYCL_TYPE_COMMIT_contiguous(base_name, src_type, __len_vec2)

/**
 * Commit `mpi_type_<base_name>` for a 3-element vector.
 *
 * The contiguous 3-element type is built into the scratch handle
 * `__tmp_mpi_type_<base_name>`. Its extent is then resized to sizeof(base_name), so that
 * consecutive vectors in a buffer are strided by the C++ object size. That size can exceed
 * 3 * sizeof(src_type), because SYCL 3-vectors are padded.
 *
 * The scratch type is freed once the resized type exists. Freeing a datatype does not
 * affect types derived from it. Without this free, every call leaked one MPI datatype
 * handle per 3-vector type.
 */
#define __SYCL_TYPE_COMMIT_len3(base_name, src_type)                                                \
    do {                                                                                           \
        check_offset_validity<base_name>();                                                        \
        MPICHECK(                                                                                  \
            MPI_Type_contiguous(__len_vec3, mpi_type_##src_type, &__tmp_mpi_type_##base_name));    \
        MPICHECK(MPI_Type_create_resized(                                                          \
            __tmp_mpi_type_##base_name, 0, sizeof(base_name), &mpi_type_##base_name));             \
        MPICHECK(MPI_Type_commit(&mpi_type_##base_name));                                          \
        MPICHECK(MPI_Type_free(&__tmp_mpi_type_##base_name));                                      \
        shamlog_debug_mpi_ln("SyclMpiTypes", "init mpi type for : " #base_name);                   \
    } while (0)

/// Commit `mpi_type_<base_name>` as 4 contiguous `src_type` elements.
#define __SYCL_TYPE_COMMIT_len4(base_name, src_type)                                                \
    __SYCL_TYPE_COMMIT_contiguous(base_name, src_type, __len_vec4)

/// Commit `mpi_type_<base_name>` as 8 contiguous `src_type` elements.
#define __SYCL_TYPE_COMMIT_len8(base_name, src_type)                                                \
    __SYCL_TYPE_COMMIT_contiguous(base_name, src_type, __len_vec8)

/// Commit `mpi_type_<base_name>` as 16 contiguous `src_type` elements.
#define __SYCL_TYPE_COMMIT_len16(base_name, src_type)                                               \
    __SYCL_TYPE_COMMIT_contiguous(base_name, src_type, __len_vec16)
90
// Sanity check run by the __SYCL_TYPE_COMMIT_len* macros before each SYCL vector type is
// described to MPI.
//
// Those macros build the MPI datatype as consecutive scalar elements starting at offset 0
// of the object. This template value-initializes a T and compares the address of its
// first component, a.s0(), with the address of the object itself. If they differ, it
// reports an error built from the message below.
//
// NOTE(review): the statement that consumes the error message is not visible in this
// listing. It is presumably a throw_with_loc(...) call, i.e. the check throws; confirm.
// NOTE(review): reinterpret_cast from a pointer to std::ptrdiff_t is
// implementation-defined. std::uintptr_t, or char* arithmetic, would be the conventional
// way to compute this offset.
91template<class T>
92void check_offset_validity() {
93 T a{};
94
95 std::ptrdiff_t base = reinterpret_cast<std::ptrdiff_t>(&a);
96 std::ptrdiff_t s0 = reinterpret_cast<std::ptrdiff_t>(&a.s0());
97
98 if (s0 - base != 0) {
100 "Offset is not valid for type {}, base = {}, s0 = {}", typeid(T).name(), base, s0));
101 }
102}
103
104void create_sycl_mpi_types() {
105
106 __SYCL_TYPE_COMMIT_len2(i64_2, i64);
107 __SYCL_TYPE_COMMIT_len2(i32_2, i32);
108 __SYCL_TYPE_COMMIT_len2(i16_2, i16);
109 __SYCL_TYPE_COMMIT_len2(i8_2, i8);
110 __SYCL_TYPE_COMMIT_len2(u64_2, u64);
111 __SYCL_TYPE_COMMIT_len2(u32_2, u32);
112 __SYCL_TYPE_COMMIT_len2(u16_2, u16);
113 __SYCL_TYPE_COMMIT_len2(u8_2, u8);
114 __SYCL_TYPE_COMMIT_len2(f16_2, f16);
115 __SYCL_TYPE_COMMIT_len2(f32_2, f32);
116 __SYCL_TYPE_COMMIT_len2(f64_2, f64);
117
118 __SYCL_TYPE_COMMIT_len3(i64_3, i64);
119 __SYCL_TYPE_COMMIT_len3(i32_3, i32);
120 __SYCL_TYPE_COMMIT_len3(i16_3, i16);
121
122 // {
123 // i16_3 a;
124
125 // MPI_Datatype types_list[3] = {mpi_type_i16,mpi_type_i16,mpi_type_i16};
126 // int block_lens[3] = {1,1,1};
127 // MPI_Aint MPI_offset[3];
128 // MPI_offset[0] = ((size_t) ( (char *)&(a.x()) - (char *)&(a) ));
129 // MPI_offset[1] = ((size_t) ( (char *)&(a.y()) - (char *)&(a) ));
130 // MPI_offset[2] = ((size_t) ( (char *)&(a.z()) - (char *)&(a) ));
131
132 // mpi::type_create_struct( 3, block_lens, MPI_offset, types_list, &mpi_type_i16_3 );
133 // /*mpi::type_create_resized(__tmp_mpi_type_i16_3, 0, sizeof(base_name),
134 // &mpi_type_i16_3);*/\ mpi::type_commit( &mpi_type_i16_3 );
135 // }
136
137 __SYCL_TYPE_COMMIT_len3(i8_3, i8);
138 __SYCL_TYPE_COMMIT_len3(u64_3, u64);
139 __SYCL_TYPE_COMMIT_len3(u32_3, u32);
140 __SYCL_TYPE_COMMIT_len3(u16_3, u16);
141 __SYCL_TYPE_COMMIT_len3(u8_3, u8);
142 __SYCL_TYPE_COMMIT_len3(f16_3, f16);
143 __SYCL_TYPE_COMMIT_len3(f32_3, f32);
144 __SYCL_TYPE_COMMIT_len3(f64_3, f64);
145
146 __SYCL_TYPE_COMMIT_len4(i64_4, i64);
147 __SYCL_TYPE_COMMIT_len4(i32_4, i32);
148 __SYCL_TYPE_COMMIT_len4(i16_4, i16);
149 __SYCL_TYPE_COMMIT_len4(i8_4, i8);
150 __SYCL_TYPE_COMMIT_len4(u64_4, u64);
151 __SYCL_TYPE_COMMIT_len4(u32_4, u32);
152 __SYCL_TYPE_COMMIT_len4(u16_4, u16);
153 __SYCL_TYPE_COMMIT_len4(u8_4, u8);
154 __SYCL_TYPE_COMMIT_len4(f16_4, f16);
155 __SYCL_TYPE_COMMIT_len4(f32_4, f32);
156 __SYCL_TYPE_COMMIT_len4(f64_4, f64);
157
158 __SYCL_TYPE_COMMIT_len8(i64_8, i64);
159 __SYCL_TYPE_COMMIT_len8(i32_8, i32);
160 __SYCL_TYPE_COMMIT_len8(i16_8, i16);
161 __SYCL_TYPE_COMMIT_len8(i8_8, i8);
162 __SYCL_TYPE_COMMIT_len8(u64_8, u64);
163 __SYCL_TYPE_COMMIT_len8(u32_8, u32);
164 __SYCL_TYPE_COMMIT_len8(u16_8, u16);
165 __SYCL_TYPE_COMMIT_len8(u8_8, u8);
166 __SYCL_TYPE_COMMIT_len8(f16_8, f16);
167 __SYCL_TYPE_COMMIT_len8(f32_8, f32);
168 __SYCL_TYPE_COMMIT_len8(f64_8, f64);
169
170 __SYCL_TYPE_COMMIT_len16(i64_16, i64);
171 __SYCL_TYPE_COMMIT_len16(i32_16, i32);
172 __SYCL_TYPE_COMMIT_len16(i16_16, i16);
173 __SYCL_TYPE_COMMIT_len16(i8_16, i8);
174 __SYCL_TYPE_COMMIT_len16(u64_16, u64);
175 __SYCL_TYPE_COMMIT_len16(u32_16, u32);
176 __SYCL_TYPE_COMMIT_len16(u16_16, u16);
177 __SYCL_TYPE_COMMIT_len16(u8_16, u8);
178 __SYCL_TYPE_COMMIT_len16(f16_16, f16);
179 __SYCL_TYPE_COMMIT_len16(f32_16, f32);
180 __SYCL_TYPE_COMMIT_len16(f64_16, f64);
181
182 __mpi_sycl_type_active = true;
183}
184
185void free_sycl_mpi_types() {
186
187 MPICHECK(MPI_Type_free(&mpi_type_i64_2));
188 MPICHECK(MPI_Type_free(&mpi_type_i32_2));
189 MPICHECK(MPI_Type_free(&mpi_type_i16_2));
190 MPICHECK(MPI_Type_free(&mpi_type_i8_2));
191 MPICHECK(MPI_Type_free(&mpi_type_u64_2));
192 MPICHECK(MPI_Type_free(&mpi_type_u32_2));
193 MPICHECK(MPI_Type_free(&mpi_type_u16_2));
194 MPICHECK(MPI_Type_free(&mpi_type_u8_2));
195 MPICHECK(MPI_Type_free(&mpi_type_f16_2));
196 MPICHECK(MPI_Type_free(&mpi_type_f32_2));
197 MPICHECK(MPI_Type_free(&mpi_type_f64_2));
198
199 MPICHECK(MPI_Type_free(&mpi_type_i64_3));
200 MPICHECK(MPI_Type_free(&mpi_type_i32_3));
201 MPICHECK(MPI_Type_free(&mpi_type_i16_3));
202 MPICHECK(MPI_Type_free(&mpi_type_i8_3));
203 MPICHECK(MPI_Type_free(&mpi_type_u64_3));
204 MPICHECK(MPI_Type_free(&mpi_type_u32_3));
205 MPICHECK(MPI_Type_free(&mpi_type_u16_3));
206 MPICHECK(MPI_Type_free(&mpi_type_u8_3));
207 MPICHECK(MPI_Type_free(&mpi_type_f16_3));
208 MPICHECK(MPI_Type_free(&mpi_type_f32_3));
209 MPICHECK(MPI_Type_free(&mpi_type_f64_3));
210
211 MPICHECK(MPI_Type_free(&mpi_type_i64_4));
212 MPICHECK(MPI_Type_free(&mpi_type_i32_4));
213 MPICHECK(MPI_Type_free(&mpi_type_i16_4));
214 MPICHECK(MPI_Type_free(&mpi_type_i8_4));
215 MPICHECK(MPI_Type_free(&mpi_type_u64_4));
216 MPICHECK(MPI_Type_free(&mpi_type_u32_4));
217 MPICHECK(MPI_Type_free(&mpi_type_u16_4));
218 MPICHECK(MPI_Type_free(&mpi_type_u8_4));
219 MPICHECK(MPI_Type_free(&mpi_type_f16_4));
220 MPICHECK(MPI_Type_free(&mpi_type_f32_4));
221 MPICHECK(MPI_Type_free(&mpi_type_f64_4));
222
223 MPICHECK(MPI_Type_free(&mpi_type_i64_8));
224 MPICHECK(MPI_Type_free(&mpi_type_i32_8));
225 MPICHECK(MPI_Type_free(&mpi_type_i16_8));
226 MPICHECK(MPI_Type_free(&mpi_type_i8_8));
227 MPICHECK(MPI_Type_free(&mpi_type_u64_8));
228 MPICHECK(MPI_Type_free(&mpi_type_u32_8));
229 MPICHECK(MPI_Type_free(&mpi_type_u16_8));
230 MPICHECK(MPI_Type_free(&mpi_type_u8_8));
231 MPICHECK(MPI_Type_free(&mpi_type_f16_8));
232 MPICHECK(MPI_Type_free(&mpi_type_f32_8));
233 MPICHECK(MPI_Type_free(&mpi_type_f64_8));
234
235 MPICHECK(MPI_Type_free(&mpi_type_i64_16));
236 MPICHECK(MPI_Type_free(&mpi_type_i32_16));
237 MPICHECK(MPI_Type_free(&mpi_type_i16_16));
238 MPICHECK(MPI_Type_free(&mpi_type_i8_16));
239 MPICHECK(MPI_Type_free(&mpi_type_u64_16));
240 MPICHECK(MPI_Type_free(&mpi_type_u32_16));
241 MPICHECK(MPI_Type_free(&mpi_type_u16_16));
242 MPICHECK(MPI_Type_free(&mpi_type_u8_16));
243 MPICHECK(MPI_Type_free(&mpi_type_f16_16));
244 MPICHECK(MPI_Type_free(&mpi_type_f32_16));
245 MPICHECK(MPI_Type_free(&mpi_type_f64_16));
246
247 __mpi_sycl_type_active = false;
248}
double f64
Alias for double.
float f32
Alias for float.
std::int8_t i8
8 bit integer
std::uint8_t u8
8 bit unsigned integer
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::uint16_t u16
16 bit unsigned integer
std::int16_t i16
16 bit integer
std::int64_t i64
64 bit integer
std::int32_t i32
32 bit integer
Utility functions for MPI error checking.
#define MPICHECK(mpicall)
Shortcut macro to check MPI return codes.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.