Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
NodeInstance.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
18#include "shambase/memory.hpp"
20#include "shambase/string.hpp"
27#include "shambackends/math.hpp"
30#include "shamcmdopt/cmdopt.hpp"
31#include "shamcmdopt/env.hpp"
32#include "shamcmdopt/tty.hpp"
34#include "shamcomm/logs.hpp"
35#include "shamcomm/mpi.hpp"
36#include "shamcomm/mpiInfo.hpp"
46#include <memory>
47#include <optional>
48#include <stdexcept>
49#include <string>
50
51namespace shamsys::instance::details {
52
/// Print a cluster-wide table of SYCL devices: one row per (rank, device) with the device id,
/// device name, platform name and device type. Each rank formats its own rows, the rows are
/// gathered with shamalgs::collective::gather_str, and rank 0 prints the assembled table.
/// NOTE(review): assumes `rank` holds this process's MPI world rank — confirm its declaration.
53 void print_device_list() {
55
56 std::string print_buf = "";
57
58 for_each_device([&](u32 key_global, const sycl::platform &plat, const sycl::device &dev) {
59 auto PlatformName = plat.get_info<sycl::info::platform::name>();
60 auto DeviceName = dev.get_info<sycl::info::device::name>();
61
// Truncate to the column widths used by the format string below (29 / 24 / 6).
62 std::string devname = shambase::trunc_str(DeviceName, 29);
63 std::string platname = shambase::trunc_str(PlatformName, 24);
64 std::string devtype = shambase::trunc_str(shambase::getDevice_type(dev), 6);
65
66 print_buf += shambase::format(
67 "| {:>4} | {:>2} | {:>29.29} | {:>24.24} | {:>6} |",
68 rank,
69 key_global,
70 devname,
71 platname,
72 devtype)
73 + "\n";
74 });
75
76 std::string recv;
77 shamalgs::collective::gather_str(print_buf, recv);
78
// Only rank 0 assembles the header/footer around the gathered rows and prints the table.
79 if (rank == 0) {
80 std::string print = "Cluster SYCL Info : \n";
81 print += ("----------------------------------------------------------------------------"
82 "----\n");
83 print += ("| rank | id | Device name | Platform name | "
84 "Type |\n");
85 print += ("----------------------------------------------------------------------------"
86 "----\n");
87 print += (recv);
88 print += ("----------------------------------------------------------------------------"
89 "----");
90 printf("%s\n", print.data());
91 }
92 }
93
94} // namespace shamsys::instance::details
95
96namespace syclinit {
97
98 bool initialized = false;
99
100 std::shared_ptr<sham::Device> device_compute;
101 std::shared_ptr<sham::Device> device_alt;
102
103 std::shared_ptr<sham::DeviceContext> ctx_compute;
104 std::shared_ptr<sham::DeviceContext> ctx_alt;
105
106 std::shared_ptr<sham::DeviceScheduler> sched_compute;
107 std::shared_ptr<sham::DeviceScheduler> sched_alt;
108
109 std::string callback_mem_perf_info() {
110 // in principle we should do it for both sched_compute and sched_alt
111 // but for now we only do one since it will return the same info
112 return "Memory usage & performance info:\n"
113 + sham::details::log_mem_perf_info(sched_compute);
114 }
115
116 void init_device_scheduling() {
117 StackEntry stack_loc{false};
118 ctx_compute = std::make_shared<sham::DeviceContext>(device_compute);
119 ctx_alt = std::make_shared<sham::DeviceContext>(device_alt);
120
121 sched_compute = std::make_shared<sham::DeviceScheduler>(ctx_compute);
122 sched_alt = std::make_shared<sham::DeviceScheduler>(ctx_alt);
123
124 test_device_scheduler(sched_compute);
125 test_device_scheduler(sched_alt);
126
127 shambase::add_callstack_gen_info_generator(callback_mem_perf_info);
128
129 // logger::raw_ln("--- Compute ---");
130 // sched_compute->print_info();
131 // logger::raw_ln("--- Alternative ---");
132 // sched_alt->print_info();
133 }
134
135 void init_queues(std::string search_key) {
136 StackEntry stack_loc{false};
137
138 auto devs = shamsys::select_devices(search_key);
139
140 device_alt = std::move(devs.device_alt);
141 device_compute = std::move(devs.device_compute);
142
143 init_device_scheduling();
144 initialized = true;
145 }
146
147 void finalize() {
148 initialized = false;
149
150 device_compute.reset();
151 device_alt.reset();
152
153 ctx_compute.reset();
154 ctx_alt.reset();
155
156 sched_compute.reset();
157 sched_alt.reset();
158 }
159}; // namespace syclinit
160
161namespace shamsys::instance {
162
163 u32 compute_queue_eu_count = 64;
164
165 u32 get_compute_queue_eu_count(u32 id) { return compute_queue_eu_count; }
166
167 bool is_initialized() { return syclinit::initialized && shamcomm::is_mpi_initialized(); };
168
170 u32 rank = shamcomm::world_rank();
171
172 std::string print_buf = "";
173
// node_local_rank() is optional; when it is empty the "local id" column shows "???".
174 std::optional<u32> loc = shamcomm::node_local_rank();
175 if (loc) {
176 print_buf = shambase::format(
177 "| {:>4} | {:>8} | {:>12} | {:>16} |\n",
178 rank,
179 *loc,
180 shambase::get_check_ref(syclinit::device_alt).device_id,
181 shambase::get_check_ref(syclinit::device_compute).device_id);
182 } else {
183 print_buf = shambase::format(
184 "| {:>4} | {:>8} | {:>12} | {:>16} |\n",
185 rank,
186 "???",
187 shambase::get_check_ref(syclinit::device_alt).device_id,
188 shambase::get_check_ref(syclinit::device_compute).device_id);
189 }
190
// Every rank contributes one row; rank 0 wraps the gathered rows in a header and prints them.
191 std::string recv;
192 shamalgs::collective::gather_str(print_buf, recv);
193
194 if (rank == 0) {
195 std::string print = "Queue map : \n";
196 print += ("----------------------------------------------------\n");
197 print += ("| rank | local id | alt queue id | compute queue id |\n");
198 print += ("----------------------------------------------------\n");
199 print += (recv);
200 print += ("----------------------------------------------------");
201 printf("%s\n\n", print.data());
202 }
203 }
204
205 namespace tmp {
206
/// Log (through shamlog_debug_sycl_ln, tag "InitSYCL") one line per SYCL device:
/// "<global id> <device name> <platform name>". This runs on every rank before MPI may be up.
/// NOTE(review): the locals `rank` and `devtype` are never used — candidates for removal.
207 void print_device_list_debug() {
208 u32 rank = 0;
209
210 std::string print_buf = "device avail : \n";
211
213 [&](u32 key_global, const sycl::platform &plat, const sycl::device &dev) {
214 auto PlatformName = plat.get_info<sycl::info::platform::name>();
215 auto DeviceName = dev.get_info<sycl::info::device::name>();
216
217 std::string devname = DeviceName;
218 std::string platname = PlatformName;
219 std::string devtype = "truc";
220
221 print_buf += std::to_string(key_global) + " " + devname + " " + platname + "\n";
222 });
223
224 shamlog_debug_sycl_ln("InitSYCL", print_buf);
225 }
226
227 } // namespace tmp
228
229 void start_sycl_auto(std::string search_key) {
230 // start sycl
231
232 tmp::print_device_list_debug();
233
234 if (syclinit::initialized) {
235 throw ShamsysInstanceException("Sycl is already initialized");
236 }
237
238 if (shamcomm::world_rank() == 0) {
239 shamlog_debug_ln("Sys", "start sycl queues ...");
240 }
241
242 syclinit::init_queues(search_key);
243 }
244
/// Start MPI and the SYCL/MPI interop datatypes.
/// Steps: fetch the MPI capabilities (honouring mpi_info.forced_state), call mpi::init,
/// sanity-check world size and rank, make MPI errors on MPI_COMM_WORLD fatal
/// (MPI_ERRORS_ARE_FATAL), synchronise on a barrier, then create the interop types.
/// @throws ShamsysInstanceException if the world size is < 1, the rank is negative, or the
///         MPI error handler cannot be set.
245 void start_mpi(MPIInitInfo mpi_info) {
246
247 shamcomm::fetch_mpi_capabilities(mpi_info.forced_state);
248
249 mpi::init(&mpi_info.argc, &mpi_info.argv);
250
252
253 // now that MPI is started we can use the formatter with rank info
255
256 if (shamcomm::world_size() < 1) {
257 throw ShamsysInstanceException("world size is < 1");
258 }
259
// NOTE(review): this tests the rank but the message mentions the world size — presumably a
// negative rank signals that the size overflowed i32; confirm the intended wording.
260 if (shamcomm::world_rank() < 0) {
261 throw ShamsysInstanceException("world size is above i32_max");
262 }
263
264 int error;
265 // error = mpi::comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN);
266 error = mpi::comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_ARE_FATAL);
267
268 if (error != MPI_SUCCESS) {
269 throw ShamsysInstanceException("failed setting the MPI error mode");
270 }
271
272 shamlog_debug_ln(
273 "Sys",
274 shambase::format(
275 "[{:03}]: \x1B[32mMPI_Init : node n {:03} | world size : {} | name = {}\033[0m",
280
281 mpi::barrier(MPI_COMM_WORLD);
282 // if(world_rank == 0){
283 if (shamcomm::world_rank() == 0) {
284 shamlog_debug_ln("NodeInstance", "------------ MPI init ok ------------");
285 shamlog_debug_ln("NodeInstance", "creating MPI type for interop");
286 }
287 create_sycl_mpi_types();
288 if (shamcomm::world_rank() == 0) {
289 shamlog_debug_ln("NodeInstance", "MPI type for interop created");
290 shamlog_debug_ln("NodeInstance", "------------ MPI / SYCL init ok ------------");
291 }
292 mpidtypehandler::init_mpidtype();
293 }
294
295 auto init_strategy = shamcmdopt::getenv_str_default_register(
296 "SHAM_MPI_INIT_STRATEGY",
297 "syclfirst",
298 "Select the MPI init strategy (mpifirst, syclfirst) [default: syclfirst]");
299
/// Start SYCL and MPI in the order selected by `init_strategy` (SHAM_MPI_INIT_STRATEGY:
/// "syclfirst" or "mpifirst"). Afterwards, refresh the MPI properties of both selected devices.
300 void init_sycl_mpi(std::string search_key, MPIInitInfo mpi_info) {
301
302 if (init_strategy == "syclfirst") {
303 start_sycl_auto(search_key);
304 start_mpi(mpi_info);
305 } else if (init_strategy == "mpifirst") {
306 start_mpi(mpi_info);
307 start_sycl_auto(search_key);
308 } else {
310 }
311
// Both SYCL and MPI are started here, so the devices can query their MPI properties.
312 shambase::get_check_ref(syclinit::device_compute).update_mpi_prop();
313 shambase::get_check_ref(syclinit::device_alt).update_mpi_prop();
314 }
315
/// NodeInstance entry point.
/// Reads the --force-dgpu-on / --force-dgpu-off flags (presumably mapping them to a forced
/// StateMPI_Aware value — confirm) and the mandatory --sycl-cfg option, then starts SYCL & MPI
/// through init_sycl_mpi.
/// @throws ShamsysInstanceException when --sycl-cfg is not given.
/// NOTE(review): mixes `shamcmdopt::` and `opts::` for option lookup — presumably aliases; confirm.
316 void init(int argc, char *argv[]) {
317
318 std::optional<shamcomm::StateMPI_Aware> forced_state = std::nullopt;
319
320 if (shamcmdopt::has_option("--force-dgpu-on")) {
322 }
323
324 if (shamcmdopt::has_option("--force-dgpu-off")) {
326 }
327
328 if (opts::has_option("--sycl-cfg")) {
329
330 std::string sycl_cfg = std::string(opts::get_option("--sycl-cfg"));
331
332 // shamlog_debug_ln("NodeInstance", "chosen sycl config :",sycl_cfg);
333
334 init_sycl_mpi(sycl_cfg, {argc, argv, forced_state});
335
336 } else {
337
338 logger::err_ln("NodeInstance", "Please specify a sycl configuration (--sycl-cfg x:x)");
339 // std::cout << "[NodeInstance] Please specify a sycl configuration (--sycl-cfg x:x)" <<
340 // std::endl;
341 throw ShamsysInstanceException("Sycl Handler need configuration (--sycl-cfg x:x)");
342 }
343 }
344
345 void close_mpi() {
346 mpidtypehandler::free_mpidtype();
347
348 free_sycl_mpi_types();
349
350 if (shamcomm::world_rank() == 0) {
351 logger::print_faint_row();
352 logger::raw_ln(" - MPI finalize \nExiting ...\n");
353 logger::raw_ln(" Hopefully it was quick :')\n");
354 }
355
356 mpi::finalize();
357 }
358
359 void close() {
360
361 close_mpi();
362
363 syclinit::finalize();
364 }
365
367 // sycl related routines
369
370 sycl::queue &get_compute_queue(u32 /*id*/) { return syclinit::sched_compute->get_queue().q; }
371
372 sycl::queue &get_alt_queue(u32 /*id*/) { return syclinit::sched_alt->get_queue().q; }
373
374 sham::DeviceScheduler &get_compute_scheduler() { return *syclinit::sched_compute; }
375
376 sham::DeviceScheduler &get_alt_scheduler() { return *syclinit::sched_alt; }
377
378 std::shared_ptr<sham::DeviceScheduler> get_compute_scheduler_ptr() {
379 return syclinit::sched_compute;
380 }
381
382 std::shared_ptr<sham::DeviceScheduler> get_alt_scheduler_ptr() { return syclinit::sched_alt; }
383
/// Print " - <device name> <global memory size>" for @p Device to std::cout.
/// NOTE(review): the size is presumably formatted with shambase::readable_sizeof — confirm.
384 void print_device_info(const sycl::device &Device) {
385 std::cout << " - " << Device.get_info<sycl::info::device::name>() << " "
387 Device.get_info<sycl::info::device::global_mem_size>())
388 << "\n";
389 }
390
391 void print_device_list() { details::print_device_list(); }
392
394 // MPI related routines
396
398
400
/// Report whether MPI can communicate directly from device memory.
/// Each rank contributes 1 if its compute device is MPI-direct capable (0 otherwise). The sum
/// over all ranks is computed with allreduce_sum, and rank 0 prints Yes (all ranks),
/// Partial (k of N) or No (none).
401 void check_dgpu_available() {
402
403 using namespace shambase::term_colors;
404
405 u32 loc_use_direct_gpu
406 = shambase::get_check_ref(syclinit::device_compute).mpi_prop.is_mpi_direct_capable;
407
408 u32 num_dgpu_use = shamalgs::collective::allreduce_sum(loc_use_direct_gpu);
409
410 if (shamcomm::world_rank() == 0) {
// NOTE(review): compares u32 num_dgpu_use with the i32 world_size() — signed/unsigned mix.
411 if (num_dgpu_use == shamcomm::world_size()) {
412 logger::raw_ln(
413 shambase::format(
414 " - MPI use Direct Comm : {}", col8b_green() + "Yes" + reset()));
415 } else if (num_dgpu_use > 0) {
416 logger::raw_ln(
417 shambase::format(
418 " - MPI use Direct Comm : {} ({} of {})",
419 col8b_yellow() + "Partial" + reset(),
420 num_dgpu_use,
422 } else {
423 logger::raw_ln(
424 shambase::format(" - MPI use Direct Comm : {}", col8b_red() + "No" + reset()));
425 }
426 }
427 }
428
429} // namespace shamsys::instance
Shamrock communication buffers.
This header performs the MPI include and wraps MPI calls.
Header file describing a Node Instance.
void start_mpi(MPIInitInfo mpi_info)
Start MPI.
void print_queue_map()
Print SYCL queue map.
sycl::queue & get_compute_queue(u32 id=0)
void init_sycl_mpi(std::string search_key, MPIInitInfo mpi_info)
Start SYCL & MPI.
sycl::queue & get_alt_queue(u32 id=0)
Get the alternative queue.
bool is_initialized()
Check whether the NodeInstance is initialized.
void close()
Close the NodeInstance, i.e. finalize both MPI and SYCL.
void close_mpi()
Finalize MPI.
Source location utility.
std::uint32_t u32
32 bit unsigned integer
Class to manage the scheduling of kernels on a device.
Exception type for the NodeInstance.
This header file contains utility functions related to exception handling in the code.
MPI string gather / allgather helpers (declarations; implementations in shamalgs/src/collective/gathe...
void gather_str(const std::string &send_vec, std::string &recv_vec)
Gathers a string from all nodes and stores the result in a std::string.
Functions related to the MPI communicator.
Provide information about MPI capabilities.
Use this header to include MPI properly.
void print_buf(sycl::buffer< T > &buf, u32 len, u32 column_count, std::string_view fmt)
Print the content of a sycl::buffer
Definition memory.hpp:181
void print()
Prints a log message with no arguments.
std::string readable_sizeof(double size)
Given a sizeof value, return a readable string. Example: readable_sizeof(1024*1024*1024) -> "1....
Definition string.hpp:139
std::string trunc_str(std::string s, u32 max_len)
Truncate a string to a specified length, adding an ellipsis if necessary.
Definition string.hpp:215
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
std::string getDevice_type(const sycl::device &Device)
Get the Device Type Name.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
std::string getenv_str_default_register(const char *env_var, std::string default_val, std::string desc)
Get the content of the environment variable if it exists and register its documentation,...
Definition env.hpp:88
bool has_option(const std::string_view &option_name)
Check if an option is present.
Definition cmdopt.cpp:128
std::string_view get_option(const std::string_view &option_name)
Get the value of an option.
Definition cmdopt.cpp:146
void print_mpi_comm_info()
Print the MPI communicator infos.
Definition mpiInfo.cpp:142
void fetch_world_info()
Gets the information about the MPI communicator.
Definition worldInfo.cpp:64
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
@ ForcedYes
Feature forced on by the user.
Definition mpiInfo.hpp:44
@ ForcedNo
Feature forced off by the user.
Definition mpiInfo.hpp:48
i32 world_size()
Gives the size of the MPI communicator.
Definition worldInfo.cpp:38
void fetch_mpi_capabilities(std::optional< StateMPI_Aware > forced_state)
Fetch the MPI capabilities.
Definition mpiInfo.cpp:70
std::string get_process_name()
Get the process name.
Definition mpiInfo.cpp:147
bool is_mpi_initialized()
Check if MPI is initialized.
Definition worldInfo.cpp:89
void print_mpi_capabilities()
Print the MPI capabilities.
Definition mpiInfo.cpp:115
DeviceSelectRet_t select_devices(std::string sycl_cfg)
Select the devices for the queues.
void change_log_format()
Change the log formatter according to the SHAMLOGFORMATTER and SHAMLOG_ERR_ON_EXCEPT environment vari...
u32 for_each_device(std::function< void(u32, const sycl::platform &, const sycl::device &)> fct)
Iterate over all SYCL devices and perform a given function.
This file contains the definition for the stacktrace related functionality.
Struct containing MPI Init information. Usage.
const std::string reset()
Get the reset terminal escape char.
const std::string col8b_yellow()
Get the yellow terminal escape char.
const std::string col8b_green()
Get the green terminal escape char.
const std::string col8b_red()
Get the red terminal escape char.
This file contains tty info getters.
Functions related to the MPI communicator.