Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
fma_chains.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
18#include "shambase/assert.hpp"
19#include "shambase/time.hpp"
22#include "shambackends/math.hpp"
23
24namespace sham::benchmarks {
25
40 template<class T>
41 inline void fma_chains(u32 i, int nrotation, T y0, T *__restrict in, T *__restrict out) {
42#define MAD_4(x, y) \
43 x = y * x + y; \
44 y = x * y + x; \
45 x = y * x + y; \
46 y = x * y + x;
47#define MAD_16(x, y) \
48 MAD_4(x, y); \
49 MAD_4(x, y); \
50 MAD_4(x, y); \
51 MAD_4(x, y);
52
53 T x = in[i];
54 T y = y0;
55 for (int j = 0; j < nrotation; j++) {
56 MAD_16(x, y);
57 }
58 out[i] = y;
59
60#undef MAD_4
61#undef MAD_16
62 }
63
71
84 template<class T>
86 DeviceScheduler_ptr sched, int N, f64 time_threshold) {
87
88 sham::DeviceQueue &q = sched->get_queue();
89
90 sham::DeviceBuffer<T> x = {size_t(N), sched};
91 sham::DeviceBuffer<T> y = {size_t(N), sched};
92
93 const T x0 = T{1.1};
94 const T y0 = -x0;
95
96 x.fill(x0);
97 y.fill(y0);
98
99 sham::EventList depends_list;
100
101 auto x_ptr = x.get_write_access(depends_list);
102 auto y_ptr = y.get_write_access(depends_list);
103
104 depends_list.wait();
105
106 u32 nrotation = 8;
107 double sec = 0;
108
109 auto run_bench = [&q, &N, &x_ptr, &y_ptr, y0](u32 nrotation) -> f64 {
110 sham::EventList empty_list{};
111
113 t.start();
114 auto e = q.submit(empty_list, [=](sycl::handler &cgh) {
115 cgh.parallel_for(sycl::range<1>{size_t(N)}, [=](sycl::item<1> item) {
116 fma_chains(item.get_linear_id(), nrotation, y0, x_ptr, y_ptr);
117 });
118 });
119 e.wait();
120 t.end();
121
122 return t.elasped_sec();
123 };
124
125 // warmup kernel
126 run_bench(4);
127
128 double ref = run_bench(0);
129
130 for (;;) {
131
132 sec = run_bench(nrotation);
133
134 if (sec >= time_threshold || nrotation >= 256 * 256 * 4) {
135 break;
136 }
137
138 nrotation *= 2;
139 }
140
141 x.complete_event_state(sycl::event{});
142 y.complete_event_state(sycl::event{});
143
144 sec -= ref;
145
146 u64 flop_per_thread = u64(nrotation) * 2_u64 * 16_u64;
147 double flop_count = double(N) * flop_per_thread;
148 double flops = flop_count / (sec);
149
150 return {SourceLocation{}.loc.function_name(), sec, flops, nrotation};
151 }
152
153} // namespace sham::benchmarks
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
Shamrock assertion utility.
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
void fill(T value, std::array< size_t, 2 > idx_range)
Fill a subpart of the buffer with a given value.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
void wait()
Wait for all events in the list to be finished.
Definition EventList.hpp:57
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:96
void end()
Stops the timer and stores the elapsed time in nanoseconds.
Definition time.hpp:111
f64 elasped_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
Definition time.hpp:123
void start()
Starts the timer.
Definition time.hpp:106
fma_chains_result fma_chains_bench(DeviceScheduler_ptr sched, int N, f64 time_threshold)
Run the fma_chains benchmark.
void fma_chains(u32 i, int nrotation, T y0, T *__restrict in, T *__restrict out)
Kernel for the fma_chains benchmark.
provide information about the source location
Structure containing the results of an fma_chains benchmark.
std::string func_name
Name of the function.
f64 seconds
Computation time in seconds.
u32 nrotations
Number of rotations performed.
constexpr const char * function_name() const noexcept
Returns the function name of the source location.