Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
InvariantParallelGenerator.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
22#include <random>
23
24namespace shamalgs::collective {
25
27 template<class Engine = std::mt19937_64>
29 Engine eng_global;
30 u64 nval_max;
31 u64 nval_current;
32 bool done;
33
34 void skip(u64 n) {
35 u64 remaining_n = nval_max - nval_current;
36 u64 to_skip = std::min(remaining_n, n);
37
38 eng_global.discard(to_skip);
39 nval_current += to_skip;
40 if (nval_current == nval_max) {
41 done = true;
42 }
43 }
44
45 std::vector<u64> next_n_sequential(u64 val_count) {
46
47 if (is_done()) {
48 return {};
49 }
50
51 u64 to_generate = std::min(val_count, nval_max - nval_current);
52
53 std::vector<u64> ret(to_generate);
54 for (u64 i = 0; i < to_generate; i++) {
55 ret[i] = eng_global();
56 }
57
58 nval_current += to_generate;
59 if (nval_current == nval_max) {
60 done = true;
61 }
62 return ret;
63 }
64
65 std::vector<u64> next_n_parallel(u64 val_count) {
66
67 if (is_done()) {
68 return {};
69 }
70
71 auto gen_info = shamalgs::collective::fetch_view(val_count);
72
73 // here i keep the temp variable for clarity
74 u64 skip_start = gen_info.head_offset;
75 u64 gen_cnt = val_count;
76 u64 skip_end = gen_info.total_byte_count - val_count - gen_info.head_offset;
77
78 shamlog_debug_ln(
79 "InvariantParallelGenerator",
80 "generate : ",
81 skip_start,
82 gen_cnt,
83 skip_end,
84 "total",
85 skip_start + gen_cnt + skip_end);
86
87 skip(skip_start);
88 std::vector<u64> ret = next_n_sequential(gen_cnt);
89 skip(skip_end);
90 return ret;
91 }
92
93 public:
94 InvariantParallelGenerator(Engine eng, u64 nval_max = u64_max)
95 : eng_global(eng), nval_max(nval_max), nval_current(0), done(false) {
96 if (nval_max == 0) {
97 done = true;
98 }
99 }
100
101 InvariantParallelGenerator(u64 seed, u64 nval_max = u64_max)
102 : InvariantParallelGenerator(Engine(seed), nval_max) {}
103
119 std::vector<u64> next_n(u64 val_count, bool sequential = false) {
120 if (sequential) {
121 u64 sum_ranks = collective::allreduce_sum<u64>(val_count);
122 return next_n_sequential(sum_ranks);
123 } else {
124 return next_n_parallel(val_count);
125 }
126 }
127
129 bool is_done() { return done; }
130
133 Engine duplicated_eng = eng_global;
134 u64 check_val = duplicated_eng();
135
136 std::vector<u64> collected_data{};
137 shamalgs::collective::vector_allgatherv({check_val}, collected_data, MPI_COMM_WORLD);
138
139 for (u64 val : collected_data) {
140 if (val != check_val) {
141 return false;
142 }
143 }
144 return true;
145 }
146 };
147} // namespace shamalgs::collective
std::uint64_t u64
64 bit unsigned integer
A parallel generator that produces the same sequence regardless of the number of ranks.
std::vector< u64 > next_n(u64 val_count, bool sequential=false)
Generate the next val_count values.
bool all_ranks_are_in_sync()
Check whether all ranks have the same generator state.
std::vector< int > vector_allgatherv(const std::vector< T > &send_vec, const MPI_Datatype &send_type, std::vector< T > &recv_vec, const MPI_Datatype &recv_type, const MPI_Comm comm)
allgatherv on vector with size query (size querying variant of vector_allgatherv_ks) //TODO add fault...
Definition exchanges.hpp:98
constexpr u64 u64_max
u64 max value