62 std::array<f64, 6> prefixes_val = {1.e3, 1.e6, 1.e9, 1.e12, 1.e15, 1.e18};
81 logger::raw_ln(
"Running micro benchmarks:");
87 microbench::p2p_bandwidth(wr1, wr2);
89 microbench::p2p_latency(wr1, wr2);
91 microbench::saxpy<f32>();
92 microbench::saxpy<f64>();
93 microbench::saxpy<f32_2>();
94 microbench::saxpy<f64_2>();
95 microbench::saxpy<f32_3>();
96 microbench::saxpy<f64_3>();
97 microbench::saxpy<f32_4>();
98 microbench::saxpy<f64_4>();
99 microbench::fma_chains_rotation<f32>();
100 microbench::fma_chains_rotation<f64>();
101 microbench::fma_chains_rotation<f32_2>();
102 microbench::fma_chains_rotation<f64_2>();
103 microbench::fma_chains_rotation<f32_3>();
104 microbench::fma_chains_rotation<f64_3>();
105 microbench::fma_chains_rotation<f32_4>();
106 microbench::fma_chains_rotation<f64_4>();
107 microbench::vector_allgather(1);
108 microbench::vector_allgather(8);
109 microbench::vector_allgather(64);
110 microbench::vector_allgather(128);
111 microbench::vector_allgather(150);
112 microbench::vector_allgather(1024);
120 u64 length = 1024UL * 1014UL * 8UL;
124 std::vector<MPI_Request> rqs;
128 bool is_used =
false;
132 mpi::barrier(MPI_COMM_WORLD);
133 f64 t_start = MPI_Wtime();
135 if (wr == wr_sender) {
136 rqs.push_back(MPI_Request{});
137 u32 rq_index = rqs.size() - 1;
138 auto &rq = rqs[rq_index];
139 shamcomm::mpi::Isend(
140 buf_send.get_ptr(), length, MPI_BYTE, wr_receiv, 0, MPI_COMM_WORLD, &rq);
144 if (wr == wr_receiv) {
147 buf_recv.get_ptr(), length, MPI_BYTE, wr_sender, 0, MPI_COMM_WORLD, &s);
154 std::vector<MPI_Status> st_lst(rqs.size());
155 if (rqs.size() > 0) {
156 shamcomm::mpi::Waitall(rqs.size(), rqs.data(), st_lst.data());
158 f64 t_end = MPI_Wtime();
159 t += t_end - t_start;
161 }
while (shamalgs::collective::allreduce_min(t) < 1);
163 f64 bw =
f64(length * loops) / t;
165 microbench_results[
"p2p_bandwidth"] = bw;
168 auto [prefix, val] = format_result(bw);
171 " - p2p bandwidth : {} {}B.s^-1 (ranks : {} -> {}) (loops : {})",
185 "can not launch this test with same ranks");
199 bool is_used =
false;
203 mpi::barrier(MPI_COMM_WORLD);
204 f64 t_start = MPI_Wtime();
208 shamcomm::mpi::Send(buf_send.get_ptr(), length, MPI_BYTE, wr2, 0, MPI_COMM_WORLD);
209 shamcomm::mpi::Recv(buf_recv.get_ptr(), length, MPI_BYTE, wr2, 1, MPI_COMM_WORLD, &s);
215 shamcomm::mpi::Recv(buf_recv.get_ptr(), length, MPI_BYTE, wr1, 0, MPI_COMM_WORLD, &s);
216 shamcomm::mpi::Send(buf_send.get_ptr(), length, MPI_BYTE, wr1, 1, MPI_COMM_WORLD);
223 f64 t_end = MPI_Wtime();
224 t += t_end - t_start;
228 }
while (shamalgs::collective::allreduce_min(bench_timer.
elasped_sec()) < 1);
230 f64 latency = t /
f64(loops);
231 microbench_results[
"p2p_latency"] = latency;
236 " - p2p latency : {:.4e} s (ranks : {} <-> {}) (loops : {})",
246 int Tsize =
sizeof(T);
248 std::string type_name;
250 if constexpr (std::is_same_v<T, f32>) {
255 }
else if constexpr (std::is_same_v<T, f64>) {
260 }
else if constexpr (std::is_same_v<T, f32_2>) {
262 init_x = {1.0f, 1.0f};
263 init_y = {2.0f, 2.0f};
265 }
else if constexpr (std::is_same_v<T, f64_2>) {
270 }
else if constexpr (std::is_same_v<T, f32_3>) {
272 init_x = {1.0f, 1.0f, 1.0f};
273 init_y = {2.0f, 2.0f, 2.0f};
274 a = {2.0f, 2.0f, 2.0f};
275 }
else if constexpr (std::is_same_v<T, f64_3>) {
277 init_x = {1.0, 1.0, 1.0};
278 init_y = {2.0, 2.0, 2.0};
280 }
else if constexpr (std::is_same_v<T, f32_4>) {
282 init_x = {1.0f, 1.0f, 1.0f, 1.0f};
283 init_y = {2.0f, 2.0f, 2.0f, 2.0f};
284 a = {2.0f, 2.0f, 2.0f, 2.0f};
285 }
else if constexpr (std::is_same_v<T, f64_4>) {
287 init_x = {1.0, 1.0, 1.0, 1.0};
288 init_y = {2.0, 2.0, 2.0, 2.0};
289 a = {2.0, 2.0, 2.0, 2.0};
294 auto bench_step = [&](
int N) {
295 return sham::benchmarks::saxpy_bench<T>(
296 instance::get_compute_scheduler_ptr(), N, init_x, init_y, a, Tsize, N < (1 << 17));
299 auto benchmark = [&]() {
300 size_t N = (1 << 15);
303 auto &dev_ptr = dev_ctx.device;
307 = std::min<size_t>(dev.prop.max_mem_alloc_size_dev, dev.prop.global_mem_size);
308 double max_size = double(max_alloc) / (Tsize * 4);
309 if (max_size >= (1 << 30)) {
310 max_size = (1 << 30);
315 for (; N <= (1 << 30) && static_cast<double>(N) <= max_size; N *= 2) {
321 if (result.seconds > 1e-3) {
329 auto result = benchmark();
331 f64 bw = result.bandwidth * 1e9;
333 f64 min_bw = shamalgs::collective::allreduce_min(bw);
334 f64 max_bw = shamalgs::collective::allreduce_max(bw);
335 f64 sum_bw = shamalgs::collective::allreduce_sum(bw);
338 microbench_results[
"saxpy_" + type_name] = sum_bw;
341 auto [prefix, val] = format_result(sum_bw);
344 " - saxpy ({}) : {} {}B.s^-1 (min = {:.1e}, max = {:.1e}, avg = {:.1e}) "
352 result.seconds * 1e3,
362 = sham::benchmarks::fma_chains_bench<T>(instance::get_compute_scheduler_ptr(), N, 0.2);
364 std::string type_name;
365 f64 flops_multiplier = 1;
366 if constexpr (std::is_same_v<T, f32>) {
368 flops_multiplier = 1;
369 }
else if constexpr (std::is_same_v<T, f64>) {
371 flops_multiplier = 1;
372 }
else if constexpr (std::is_same_v<T, f32_2>) {
374 flops_multiplier = 2;
375 }
else if constexpr (std::is_same_v<T, f64_2>) {
377 flops_multiplier = 2;
378 }
else if constexpr (std::is_same_v<T, f32_3>) {
380 flops_multiplier = 3;
381 }
else if constexpr (std::is_same_v<T, f64_3>) {
383 flops_multiplier = 3;
384 }
else if constexpr (std::is_same_v<T, f32_4>) {
386 flops_multiplier = 4;
387 }
else if constexpr (std::is_same_v<T, f64_4>) {
389 flops_multiplier = 4;
394 f64 min_flop = shamalgs::collective::allreduce_min(result.flops);
395 f64 max_flop = shamalgs::collective::allreduce_max(result.flops);
396 f64 sum_flop = shamalgs::collective::allreduce_sum(result.flops);
399 microbench_results[
"fma_chains_" + type_name] = sum_flop * flops_multiplier;
402 auto [prefix, val] = format_result(sum_flop * flops_multiplier);
405 " - fma_chains ({}) : {} {}flops (min = {:.1e}, max = {:.1e}, avg = {:.1e}) "
406 "({:.1e} ms, rotations = {})",
410 min_flop * flops_multiplier,
411 max_flop * flops_multiplier,
412 avg_flop * flops_multiplier,
413 result.seconds * 1e3,
421 std::vector<u64> send_data(el_per_rank);
423 std::vector<u64> recv_data;
428 auto benchmark_step = [&]() {
429 shamcomm::mpi::Barrier(MPI_COMM_WORLD);
430 f64 t_start = MPI_Wtime();
431 shamalgs::collective::vector_allgatherv(send_data, recv_data, MPI_COMM_WORLD);
432 f64 t_end = MPI_Wtime();
433 t += t_end - t_start;
439 }
while (shamalgs::collective::allreduce_min(t) < 0.1);
443 f64 min_t = shamalgs::collective::allreduce_min(t);
444 f64 max_t = shamalgs::collective::allreduce_max(t);
445 f64 sum_t = shamalgs::collective::allreduce_sum(t);
448 microbench_results[
"vector_allgather_u64_" + std::to_string(el_per_rank)] = avg_t;
453 " - vector_allgather (u64, n={:4}) : {:.3e} s (min = {:.2e}, max = {:.2e}, loops = "