Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
shamtest.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
19#include "shambase/string.hpp"
21#include "shambase/time.hpp"
25#include "shamcmdopt/ci_env.hpp"
26#include "shamcmdopt/env.hpp"
27#include "shamcomm/logs.hpp"
29#include "shamrock/version.hpp"
34#include "shamtest.hpp"
37#include <pybind11/embed.h>
38#include <unordered_map>
39#include <cmath>
40#include <cstdlib>
41#include <filesystem>
42#include <sstream>
#include <stdexcept>
43#include <string>
44#include <vector>
45
46namespace shamtest {
47
// When set, _start_test_print replaces the "[k/n]" progress counter with "- : ".
// NOTE(review): not assigned anywhere in the visible part of this file — confirm where it is set.
bool is_run_only{false};

// Copied from TestConfig::full_output at the start of run_all_tests; when set,
// _end_test_print lists successful assertions as well as failed ones.
bool is_full_output_mode{false};
50
58 void _start_test_print(details::Test &test, u32 test_num, u32 test_count) {
59
60 std::string output;
61 if (is_run_only) {
62 output += ("- : ");
63 } else {
64 output += shambase::format("- [{}/{}] :", test_num + 1, test_count);
65 }
66
67 bool any_node_cnt = test.node_count == -1;
68 if (any_node_cnt) {
69 output += (" [any] ");
70 } else {
71 output += shambase::format(" [{:03}] ", test.node_count);
72 }
73
74 output += "\033[;34m" + test.name + "\033[0m\n";
75 ON_RANK_0(printf("%s", output.c_str()));
76 }
77
84 void _end_test_print(std::vector<details::TestResult> &rank_results, shambase::Timer &timer) {
85
86 for (int rank = 0; rank < rank_results.size(); rank++) {
87 auto &res = rank_results[rank];
88
89 for (unsigned int j = 0; j < res.asserts.asserts.size(); j++) {
90
91 if (is_full_output_mode || (!res.asserts.asserts[j].value)) {
92 printf(" Rank %3d [%d/%zu] : ", rank, j + 1, res.asserts.asserts.size());
93 printf("%-20s", res.asserts.asserts[j].name.c_str());
94
95 if (res.asserts.asserts[j].value) {
96 std::cout << " (\033[;32mSuccess\033[0m)\n";
97 } else {
98 std::cout << " (\033[1;31m Fail \033[0m)\n";
99 if (!res.asserts.asserts[j].comment.empty()) {
100 std::cout << "----- logs : \n"
101 << res.asserts.asserts[j].comment << "\n-----" << std::endl;
102 }
103 }
104 }
105 }
106 }
107
108 u32 assert_count = 0;
109 u32 success_cnt = 0;
110 for (int rank = 0; rank < rank_results.size(); rank++) {
111 auto &res = rank_results[rank];
112 for (unsigned int j = 0; j < res.asserts.asserts.size(); j++) {
113 if (res.asserts.asserts[j].value) {
114 success_cnt++;
115 }
116 assert_count++;
117 }
118 }
119
120 if (success_cnt == assert_count) {
121 std::cout << " -> Result : \033[;32mSuccess\033[0m";
122 } else {
123 std::cout << " -> Result : \033[1;31m Fail \033[0m";
124 }
125
126 std::string s_assert = shambase::format(" [{}/{}] ", success_cnt, assert_count);
127 printf("%-15s", s_assert.c_str());
128 std::cout << " (" << timer.get_time_str() << ")" << std::endl;
129
131 if (success_cnt != assert_count) {
132 logger::raw_ln(shambase::format("##[error]Test {} failed", rank_results[0].name));
133 }
134 }
135
136 std::cout << std::endl;
137 }
138
146 bool is_test_failed(details::TestResult &res) {
147 for (unsigned int j = 0; j < res.asserts.asserts.size(); j++) {
148
149 if (!res.asserts.asserts[j].value) {
150 return true;
151 }
152 }
153
154 return false;
155 }
156
163 std::string gen_fail_log(details::TestResult &res) {
164 std::string out = "";
165
166 std::string sep = "\n-------------------------------------\n";
167
168 out += " - Test : \033[;34m" + res.name
169 + "\033[0m world_rank = " + std::to_string(res.world_rank);
170 out += "\n Assertion list :\n";
171
172 for (unsigned int j = 0; j < res.asserts.asserts.size(); j++) {
173
174 out += shambase::format_printf(" - [%d/%zu] ", j + 1, res.asserts.asserts.size());
175 out += shambase::format_printf("%-20s", res.asserts.asserts[j].name.c_str());
176
177 if (res.asserts.asserts[j].value) {
178 out += " (\033[;32mSuccess\033[0m)\n";
179 } else {
180 out += " (\033[1;31m Fail \033[0m)\n";
181 }
182
183 if ((!res.asserts.asserts[j].value) && !res.asserts.asserts[j].comment.empty()) {
184 out += " -> failed assert logs : " + sep + res.asserts.asserts[j].comment + sep;
185 }
186 }
187
188 out += "\n\n";
189
190 return out;
191 }
192
198 void _print_summary(std::vector<details::TestResult> &results) {
199 if (shamcomm::world_rank() > 0) {
200 return;
201 }
202
203 logger::print_faint_row();
204 logger::print_faint_row();
205 logger::print_faint_row();
206 logger::raw_ln(
207 shambase::term_colors::bold() + "Test Report :" + shambase::term_colors::reset());
208 logger::raw_ln();
209
210 u32 test_count = results.size();
211 u32 succ_count = 0;
212 u32 fail_count = 0;
213 std::string log = "";
214 for (details::TestResult &res : results) {
215 if (!is_test_failed(res)) {
216 succ_count++;
217 } else {
218 fail_count++;
219 log += gen_fail_log(res);
220 }
221 }
222
223 std::cout << "Test suite status : ";
224 if (fail_count == 0) {
225 std::cout << " (\033[;32mSuccess\033[0m)";
226 printf(" [%d/%d] \n", succ_count, test_count);
227 } else {
228 std::cout << " (\033[1;31m Fail \033[0m)";
229 printf(" [%d/%d] \n", succ_count, test_count);
230 std::cout << "\nFailed tests : \n\n" << log;
231 }
232
233 logger::print_faint_row();
234 logger::print_faint_row();
235 logger::print_faint_row();
236 }
237
239 std::vector<details::TestResult> gather_tests(
240 std::vector<details::TestResult> rank_result, usize &gather_bytecount) {
241 if (shamcomm::world_size() == 1) {
242 return rank_result;
243 }
244
245 // generate payload
246 std::basic_stringstream<byte> outrank;
247
248 shambase::stream_write_vector(outrank, rank_result);
249
250 std::basic_string<byte> gathered;
251 shamalgs::collective::gather_basic_str(outrank.str(), gathered);
252
253 if (shamcomm::world_rank() != 0) {
254 return {};
255 }
256
257 gather_bytecount = gathered.size();
258
259 std::basic_stringstream<byte> reader(gathered);
260
261 std::vector<details::TestResult> out;
262
263 for (u32 i = 0; i < shamcomm::world_size(); i++) {
264 shambase::stream_read_vector(reader, out);
265 }
266
267 return out;
268 }
269
274 void print_test_list() {
275
276 if (shamcomm::world_rank() > 0) {
277 return;
278 }
279
280 using namespace shamtest::details;
281
282 auto print_list = [&](TestType t) {
283 for (auto test : static_init_vec_tests) {
284 if (test.type == t) {
285 if (test.node_count == -1) {
286 printf("- [any] %-15s\n", test.name.c_str());
287 } else {
288 printf("- [%03d] %-15s\n", test.node_count, test.name.c_str());
289 }
290 }
291 }
292 };
293
294 printf("--- Benchmark ---\n");
295
296 print_list(Benchmark);
297
298 printf("--- LongBenchmark ---\n");
299
300 print_list(LongBenchmark);
301
302 printf("--- ValidationTest ---\n");
303
304 print_list(ValidationTest);
305
306 printf("--- LongValidationTest ---\n");
307
308 print_list(LongValidationTest);
309
310 printf("--- Unittest ---\n");
311
312 print_list(Unittest);
313 }
314
316 void write_json_report(std::vector<details::TestResult> &results, std::string outfile) {
317 if (shamcomm::world_rank() > 0) {
318 return;
319 }
320
321 std::stringstream rank_test_res_out;
322 for (details::TestResult &res : results) {
323 rank_test_res_out << res.serialize_json() << ",";
324 }
325
326 std::string out_res_string = rank_test_res_out.str();
327
328 // generate json output and write it into the specified file
329
330 if (out_res_string.back() == ',') {
331 out_res_string = out_res_string.substr(0, out_res_string.size() - 1);
332 }
333
334 std::string s_out;
335
336 s_out = "{\n";
337
338 s_out += R"( "commit_hash" : ")" + git_commit_hash + "\",\n";
339 s_out += R"( "world_size" : ")" + std::to_string(shamcomm::world_size()) + "\",\n";
340
341#if defined(SYCL_COMP_INTEL_LLVM)
342 s_out += R"( "compiler" : "DPCPP",)"
343 "\n";
344#elif defined(SYCL_COMP_HIPSYCL)
345 s_out += R"( "compiler" : "HipSYCL",)"
346 "\n";
347#else
348 s_out += R"( "compiler" : "Unknown",)"
349 "\n";
350#endif
351
352 s_out += R"( "comp_args" : ")" + compile_arg + "\",\n";
353
354 s_out += R"( "results" : )"
355 "[\n\n";
356 s_out += shambase::increase_indent(out_res_string);
357 s_out += "\n ]\n}";
358
359 // printf("%s\n",s_out.c_str());
360
361 shambase::write_string_to_file(outfile, s_out);
362 }
363
365 void write_tex_report(std::vector<details::TestResult> &results, bool mark_fail) {
366 if (shamcomm::world_rank() > 0) {
367 return;
368 }
369
370 logger::raw("write report Tex : ");
371
373 "tests/report.tex", details::make_test_report_tex(results, mark_fail));
374
375 logger::raw_ln("Done (tests/report.tex)");
376 }
377
378 std::vector<u32> select_print_tests(TestConfig cfg) {
379
380 bool run_unit_test = cfg.run_unittest;
381 bool run_validation_test = cfg.run_validation;
382 bool run_longvalidation_test = cfg.run_validation && cfg.run_long_tests;
383 bool run_benchmark_test = cfg.run_benchmark;
384 bool run_longbenchmark_test = cfg.run_benchmark && cfg.run_long_tests;
385
386 auto can_run = [&](shamtest::details::Test &t) -> bool {
387 bool any_node_cnt = (t.node_count == -1);
388 bool world_size_ok = t.node_count == shamcomm::world_size();
389
390 bool can_run_type = false;
391
392 auto test_type = t.type;
393 can_run_type |= (run_unit_test && (Unittest == test_type));
394 can_run_type |= (run_validation_test && (ValidationTest == test_type));
395 can_run_type |= (run_longvalidation_test && (LongValidationTest == test_type));
396 can_run_type |= (run_benchmark_test && (Benchmark == test_type));
397 can_run_type |= (run_longbenchmark_test && (LongBenchmark == test_type));
398
399 return can_run_type && (any_node_cnt || world_size_ok);
400 };
401
402 auto print_test = [&](shamtest::details::Test &t, bool enabled) {
403 bool any_node_cnt = (t.node_count == -1);
404
405 std::string output = "";
406
407 if (enabled) {
408
409 if (any_node_cnt) {
410 output += (" - [\033[;32many\033[0m] ");
411 } else {
412 output += shambase::format(" - [\033[;32m{:03}\033[0m] ", t.node_count);
413 }
414 output += "\033[;32m" + t.name + "\033[0m\n";
415
416 } else {
417 if (any_node_cnt) {
418 output += (" - [\033[;31many\033[0m] ");
419 } else {
420 output += shambase::format(" - [\033[;31m{:03}\033[0m] ", t.node_count);
421 }
422 output += "\033[;31m" + t.name + "\033[0m\n";
423 }
424
425 printf("%s", output.c_str());
426 };
427
428 using namespace shamtest::details;
429
430 std::vector<u32> selected_tests;
431
432 auto run_only_check = [&](std::string test_name) -> bool {
433 if (cfg.run_only) {
434 return *cfg.run_only == test_name;
435 } else {
436 return true;
437 }
438 };
439
440 auto test_loop = [&](TestType t) {
441 for (u32 i = 0; i < static_init_vec_tests.size(); i++) {
442 if (static_init_vec_tests[i].type == t) {
443
444 bool run_test = can_run(static_init_vec_tests[i])
445 && run_only_check(static_init_vec_tests[i].name);
446
447 ON_RANK_0(print_test(static_init_vec_tests[i], run_test));
448
449 if (run_test) {
450 selected_tests.push_back(i);
451 }
452 }
453 }
454 };
455
456 ON_RANK_0(printf("\n------------ Tests list --------------\n"));
457 if (run_benchmark_test) {
458 ON_RANK_0(printf("--- Benchmark ---\n"));
459 test_loop(Benchmark);
460 }
461
462 if (run_benchmark_test) {
463 ON_RANK_0(printf("--- LongBenchmark ---\n"));
464 test_loop(LongBenchmark);
465 }
466
467 if (run_validation_test) {
468 ON_RANK_0(printf("--- ValidationTest ---\n"));
469 test_loop(ValidationTest);
470 }
471
472 if (run_longvalidation_test) {
473 ON_RANK_0(printf("--- LongValidationTest ---\n"));
474 test_loop(LongValidationTest);
475 }
476
477 if (run_unit_test) {
478 ON_RANK_0(printf("--- Unittest ---\n"));
479 test_loop(Unittest);
480 }
481 ON_RANK_0(printf("--------------------------------------\n\n"));
482
483 return selected_tests;
484 }
485
/**
 * @brief Run every selected test and report the results.
 *
 * Steps, as written below:
 *  1. Select and list the tests (select_print_tests).
 *  2. Start the embedded Python interpreter and import `shamrock`.
 *  3. Run each selected test between MPI barriers while timing it, then gather
 *     and print its per-rank results on rank 0.
 *  4. Close the interpreter and gather all results.
 *  5. Print the summary, optionally write the JSON report, and always write
 *     the TeX report.
 *
 * @param argc forwarded to shambindings::set_sys_argv
 * @param argv forwarded to shambindings::set_sys_argv
 * @param cfg  test runner configuration (selection and output options)
 * @return 255 if any gathered test has a failed assertion, 0 otherwise.
 *         NOTE(review): with world_size > 1, gather_tests returns an empty vector
 *         on ranks other than 0, so those ranks always return 0 — confirm intended.
 */
486 int run_all_tests(int argc, char *argv[], TestConfig cfg) {
// NOTE(review): presumably registers this frame with the stacktrace helpers — confirm.
487 StackEntry stack{};
488
// Makes _end_test_print list successful assertions as well as failed ones.
489 is_full_output_mode = cfg.full_output;
490
491 mpi::barrier(MPI_COMM_WORLD);
// Indices into static_init_vec_tests of the tests to run.
492 std::vector<u32> selected_tests = select_print_tests(cfg);
493 mpi::barrier(MPI_COMM_WORLD);
494
// 0-based position of the current test, used for the "[k/n]" header.
495 u32 test_loc_cnt = 0;
496
497 bool has_error = false;
498
499 logger::info_ln("Test", "start python interpreter");
500 py::initialize_interpreter();
501
502 ON_RANK_0(shamcomm::logs::print_faint_row());
// NOTE(review): original line 503 is absent from this listing; possibly a call to
// shambindings::modify_py_sys_path — confirm against the real file.
504 shambindings::set_sys_argv(argc, argv);
505 ON_RANK_0(shamcomm::logs::print_faint_row());
506
507 // import shamrock in pybind
508 py::exec(R"(
509 import shamrock
510 )");
511
// Also creates tests/, where write_tex_report writes tests/report.tex.
512 std::filesystem::create_directories("tests/figures");
513
514 using namespace shamtest::details;
515
516 ON_RANK_0(logger::raw_ln("Running tests : "));
517 ON_RANK_0(shamcomm::logs::print_faint_row());
518
// Results of the tests on this rank, in execution order.
519 std::vector<TestResult> results;
520 for (u32 i : selected_tests) {
521
// NOTE(review): `test` is not declared in the visible lines (original line 522 is
// absent from this listing); presumably `Test &test = static_init_vec_tests[i];` — confirm.
523
524 _start_test_print(test, test_loc_cnt, selected_tests.size());
525
// NOTE(review): presumably disables any exception-generation callback while the
// test runs (scoped to this loop iteration) — confirm against exception.hpp.
526 [[maybe_unused]] shambase::scoped_exception_gen_callback scoped_callback(nullptr);
527
// Barriers on both sides so the timer measures the test on synchronized ranks.
528 mpi::barrier(MPI_COMM_WORLD);
529 shambase::Timer timer;
530 timer.start();
531 TestResult res = test.run();
532 timer.end();
533 mpi::barrier(MPI_COMM_WORLD);
534
// Per-test gather so that rank 0 can print every rank's assertions right away.
535 usize gather_bytecount = 0;
536 std::vector<TestResult> gathered = gather_tests({res}, gather_bytecount);
537 if (shamcomm::world_rank() == 0) {
538 logger::raw_ln("Test result gathered :", gather_bytecount, "bytes");
539 _end_test_print(gathered, timer);
540 }
541
542 results.push_back(std::move(res));
543
544 test_loc_cnt++;
545 }
546
547 logger::info_ln("Test", "close python interpreter");
548 py::finalize_interpreter();
549
// Gather all results onto rank 0; other ranks get an empty vector when world_size > 1.
550 usize gather_bytecount = 0;
551 results = gather_tests(std::move(results), gather_bytecount);
552
553 if (shamcomm::world_rank() == 0) {
554 logger::print_faint_row();
555 logger::raw_ln("Test result gathered :", gather_bytecount, "bytes");
556 }
557
558 for (TestResult &res : results) {
559 has_error = has_error || is_test_failed(res);
560 }
561
// The report writers below return immediately on ranks other than 0.
562 _print_summary(results);
563
564 if (cfg.json_output) {
565 write_json_report(results, *cfg.json_output);
566 }
567
568 write_tex_report(results, has_error);
569
570 i32 errcode;
571 if (has_error) {
572 errcode = 255;
573 } else {
574 errcode = 0;
575 }
576
577 mpi::barrier(MPI_COMM_WORLD);
578
579 if (shamcomm::world_rank() == 0) {
580 logger::raw_ln("Tests done exiting ... exitcode =", errcode);
581 }
582 mpi::barrier(MPI_COMM_WORLD);
// NOTE(review): original line 583 is absent from this listing — confirm its content.
584
585 return errcode;
586 }
587
588 void gen_test_list(std::string_view outfile) {
589 // logger::raw_ln("Test list ...", outfile);
590
591 using namespace details;
592
593 std::array rank_list{1, 2, 3, 4};
594
595 auto get_pref_type = [](TestType t) -> std::string {
596 switch (t) {
597 case Benchmark : return "Benchmark";
598 case LongBenchmark : return "LongBenchmark";
599 case ValidationTest : return "ValidationTest";
600 case LongValidationTest: return "LongValidationTest";
601 case Unittest : return "Unittest";
602 }
603 };
604
605 auto get_arg = [](TestType t) -> std::string {
606 switch (t) {
607 case Benchmark : return "--benchmark";
608 case LongBenchmark : return "--long-test --benchmark";
609 case ValidationTest : return "--validation";
610 case LongValidationTest: return "--long-test --validation";
611 case Unittest : return "--unittest";
612 }
613 };
614
615 auto get_test_name = [&](Test t, int ranks) -> std::string {
616 std::string name = get_pref_type(t.type) + "/" + t.name
617 + shambase::format(
618 "(ranks={})"
619 //"{}"
620 ,
621 ranks);
622 // shambase::replace_all(name, "/", "");
623 return name;
624 };
625
626 std::ofstream filestream;
627 filestream.open(std::string(outfile));
628
629 std::vector<std::string> cmake_test_list;
630
631 auto add_test = [&](Test t, int ranks) {
632 std::string tname = get_test_name(t, ranks);
633 cmake_test_list.push_back(tname);
634
635 std::string ret = "add_test(\"";
636 ret += tname;
637 ret += "\"";
638 if (ranks > 1) {
639 ret += " mpirun -n " + std::to_string(ranks) + " ../shamrock_test --sycl-cfg 0:0";
640 } else {
641 ret += " ../shamrock_test --sycl-cfg 0:0";
642 }
643 ret += " --run-only \"" + std::string(t.name) + "\"";
644 ret += " " + get_arg(t.type);
645 ret += ")\n";
646 filestream << ret;
647 };
648
649 for (const Test &t : static_init_vec_tests) {
650 if (t.type == Benchmark || t.type == LongBenchmark)
651 continue;
652 if (t.node_count == -1) {
653 for (int ncount : rank_list) {
654 add_test(t, ncount);
655 }
656 } else {
657 add_test(t, t.node_count);
658 }
659 }
660
661 filestream << "\n";
662
663 auto REF_FILES_PATH = shamcmdopt::getenv_str("REF_FILES_PATH");
664
665 if (REF_FILES_PATH) {
666 filestream << "set_tests_properties(\n";
667 for (auto tname : cmake_test_list) {
668 filestream << " \"" << tname << "\"\n";
669 }
670 filestream << " PROPERTIES\n";
671 filestream << " ENVIRONMENT \"REF_FILES_PATH=" + *REF_FILES_PATH << "\"\n";
672 filestream << ")\n";
673 }
674
675 filestream.close();
676 }
677
678} // namespace shamtest
This header does the MPI include and wraps MPI calls.
Header file describing a Node Instance.
void close()
Close the NodeInstance, i.e. finalize both MPI and SYCL.
header describing return type of a test, and the type of the test
TestType
Describe the type of the performed test.
std::uint32_t u32
32 bit unsigned integer
std::size_t usize
size_t alias
std::int32_t i32
32 bit integer
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:96
std::string get_time_str() const
Converts the stored nanosecond time to a string representation.
Definition time.hpp:117
void end()
Stops the timer and stores the elapsed time in nanoseconds.
Definition time.hpp:111
void start()
Starts the timer.
Definition time.hpp:106
Scoped exception generator callback.
Definition exception.hpp:65
This header file contains utility functions related to exception handling in the code.
MPI string gather / allgather helpers (declarations; implementations in shamalgs/src/collective/gathe...
void gather_basic_str(const std::basic_string< byte > &send_vec, std::basic_string< byte > &recv_vec)
same as gather_str but with std::basic_string
Namespace for internal details of the logs module.
void stream_read_vector(std::basic_stringstream< byte > &stream, std::vector< T > &vec)
Read a vector from the byte stream. Note: this appends the read objects to the vector without resetting it.
void write_string_to_file(std::string filename, std::string s)
dump a string to a file
Definition string.hpp:168
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
std::string increase_indent(std::string in, std::string delim="\n ")
Increase indentation of a string.
Definition string.hpp:197
void stream_write_vector(std::basic_stringstream< byte > &stream, std::vector< T > &vec)
write the vector into the bytestream
std::optional< std::string > getenv_str(const char *env_var)
Get the content of the environment variable if it exists.
Definition env.cpp:24
bool is_ci_github_actions()
Check if the environment variable GITHUB_ACTIONS is set.
Definition ci_env.cpp:40
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
i32 world_size()
Gives the size of the MPI communicator.
Definition worldInfo.cpp:38
implementation details of the test library
Definition DataNode.hpp:23
std::vector< Test > static_init_vec_tests
Static init vector containing the list of all the tests in the code see : programming guide : Static ...
Definition shamtest.hpp:36
std::string make_test_report_tex(std::vector< TestResult > &results, bool mark_fail)
Make the tex report.
namespace containing stuff related to the test library
Definition DataNode.hpp:23
void gen_test_list(std::string_view outfile)
output test list to a file
int run_all_tests(int argc, char *argv[], TestConfig cfg)
run all the tests
Pybind11 include and definitions.
main include file for testing
This file contains the definition for the stacktrace related functionality.
void modify_py_sys_path(bool do_print)
Modify Python sys.path to point to one detected during cmake invocation.
void set_sys_argv(int argc, char *argv[])
set the value of sys.argv
Configuration of the test runner.
Definition shamtest.hpp:63
std::optional< std::string > run_only
Run only the test whose name exactly matches this string (compared by string equality, not as a regex).
Definition shamtest.hpp:82
bool run_long_tests
run also long tests
Definition shamtest.hpp:77
bool run_benchmark
run benchmarks
Definition shamtest.hpp:80
bool run_unittest
run unittests
Definition shamtest.hpp:78
bool full_output
Should display all logs including all asserts.
Definition shamtest.hpp:69
std::optional< std::string > json_output
Should output a json report.
Definition shamtest.hpp:75
bool run_validation
run validation tests
Definition shamtest.hpp:79
Information about a test.
Definition Test.hpp:26
i32 node_count
Node count of the test.
Definition Test.hpp:29
std::string name
Name of the test.
Definition Test.hpp:28
TestType type
Type of test.
Definition Test.hpp:27
header file to manage sycl
typedefs and macros
Functions related to the MPI communicator.
#define ON_RANK_0(x)
Macro to execute code only on rank 0.
Definition worldInfo.hpp:73