compute_histogram.hpp
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

#pragma once

// NOTE: some includes are collapsed in this listing; the code below also relies
// on the shambackends headers providing DeviceBuffer, DeviceScheduler and
// kernel_call.
#include "shambase/string.hpp"
#include "shamcomm/logs.hpp"
#include <shambackends/sycl.hpp>
#include <optional>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>

namespace shamalgs::primitives {

    namespace impl {
        enum class histo_impl { reference, naive_gpu, gpu_team_fetching, gpu_oversubscribe };

        // NOTE: reconstructed class head; the original declaration is collapsed in
        // this listing, and the base class assumed to provide ensure_init() is elided.
        class HistogramImplControl /* : <base elided> */ {

            bool was_init = false;
            histo_impl impl = histo_impl::reference;

            virtual std::string impl_get_alg_name() const { return "compute_histogram"; }

            virtual bool impl_was_configured(const sham::DeviceScheduler_ptr &) const {
                return was_init;
            }

            virtual std::string impl_get_config(const sham::DeviceScheduler_ptr &) const {
                switch (impl) {
                case histo_impl::reference        : return "reference";
                case histo_impl::naive_gpu        : return "naive_gpu";
                case histo_impl::gpu_team_fetching: return "gpu_team_fetching";
                case histo_impl::gpu_oversubscribe: return "gpu_oversubscribe";
                }
            }

            virtual std::string impl_get_default_config(
                const sham::DeviceScheduler_ptr &dev_sched) const {
                if (dev_sched->ctx->device->prop.type == sham::DeviceType::GPU) {
                    return "gpu_oversubscribe";
                } else {
                    return "naive_gpu"; // portable and fast everywhere
                }
            }

            virtual void impl_set_config(
                const sham::DeviceScheduler_ptr &, const std::string &config) {
                if (config == "reference") {
                    impl = histo_impl::reference;
                } else if (config == "naive_gpu") {
                    impl = histo_impl::naive_gpu;
                } else if (config == "gpu_team_fetching") {
                    impl = histo_impl::gpu_team_fetching;
                } else if (config == "gpu_oversubscribe") {
                    impl = histo_impl::gpu_oversubscribe;
                } else {
                    shambase::throw_unimplemented("unknown implementation");
                }
                was_init = true;
            }

            virtual std::vector<std::string> impl_get_avail_configs(
                const sham::DeviceScheduler_ptr &) {
                return {"reference", "naive_gpu", "gpu_team_fetching", "gpu_oversubscribe"};
            }

          public:
            histo_impl get_impl(const sham::DeviceScheduler_ptr &dev_sched) {
                this->ensure_init(dev_sched);
                return impl;
            }
        };

        inline HistogramImplControl compute_histogram_impl_control{};

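        // Usage sketch (illustrative; assumes "dev_sched" is an initialized
        // sham::DeviceScheduler_ptr): the dispatcher below queries this global
        // selector to pick the kernel, e.g.
        //
        //     auto which = compute_histogram_impl_control.get_impl(dev_sched);
        //
        // get_impl() calls ensure_init(), which (from the elided base class) is
        // expected to apply impl_get_default_config() when no configuration was set.
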
        template<class T, class Tbins, class... Targs, class Tfunctor>
        inline void compute_histogram_reference(
            const sham::DeviceBuffer<Tbins> &bin_edge_inf,
            const sham::DeviceBuffer<Tbins> &bin_edge_sup,
            size_t nbins,
            size_t element_count,
            Tfunctor &&functor,
            sham::DeviceBuffer<T> &result,
            const sham::DeviceBuffer<Targs> &...input_data) {

            auto result_vec = result.copy_to_stdvec();

            auto cpu_basic_impl = [&](const std::vector<Tbins> &bin_edge_inf,
                                      const std::vector<Tbins> &bin_edge_sup,
                                      const std::vector<Targs> &...in_data,
                                      std::vector<T> &result) {
                for (size_t ibin = 0; ibin < nbins; ibin++) {
                    Tbins edge_inf = bin_edge_inf[ibin];
                    Tbins edge_sup = bin_edge_sup[ibin];

                    T accumulator = 0;

                    for (size_t i = 0; i < element_count; i++) {
                        bool has_value = false;
                        auto tmp = functor(edge_inf, edge_sup, in_data[i]..., has_value);
                        if (has_value) {
                            accumulator += tmp;
                        }
                    }

                    result[ibin] = accumulator;
                }
            };

            cpu_basic_impl(
                bin_edge_inf.copy_to_stdvec(),
                bin_edge_sup.copy_to_stdvec(),
                input_data.copy_to_stdvec()...,
                result_vec);

            result.copy_from_stdvec(result_vec);
        }
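
        // All implementations share the same functor contract: for a bin
        // [edge_inf, edge_sup) and one element of each input buffer, the functor
        // returns the value to accumulate and must set has_value to true only if
        // the element contributes to the bin. An illustrative sketch (hypothetical
        // weighted count over positions x and weights w):
        //
        //     auto weighted = [](double edge_inf, double edge_sup, double x,
        //                        double w, bool &has_value) -> double {
        //         has_value = (x >= edge_inf && x < edge_sup);
        //         return has_value ? w : 0.0;
        //     };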

        template<class T, class Tbins, class... Targs, class Tfunctor>
        inline void compute_histogram_naive_gpu(
            const sham::DeviceScheduler_ptr &dev_sched,
            const sham::DeviceBuffer<Tbins> &bin_edge_inf,
            const sham::DeviceBuffer<Tbins> &bin_edge_sup,
            size_t nbins,
            size_t element_count,
            Tfunctor &&functor,
            sham::DeviceBuffer<T> &result,
            const sham::DeviceBuffer<Targs> &...input_data) {

            sham::kernel_call( // call site restored; collapsed in this listing
                dev_sched->get_queue(),
                sham::MultiRef{bin_edge_inf, bin_edge_sup, input_data...},
                sham::MultiRef{result},
                nbins,
                [element_count, functor](
                    u32 ibin,
                    const Tbins *__restrict bin_edge_inf,
                    const Tbins *__restrict bin_edge_sup,
                    const Targs *__restrict... in_data,
                    T *__restrict result) {
                    Tbins edge_inf = bin_edge_inf[ibin];
                    Tbins edge_sup = bin_edge_sup[ibin];

                    T accumulator = 0;

                    for (size_t i = 0; i < element_count; i++) {
                        bool has_value = false;
                        T tmp = functor(edge_inf, edge_sup, in_data[i]..., has_value);
                        if (has_value) {
                            accumulator += tmp;
                        }
                    }

                    result[ibin] = accumulator;
                });
        }

        template<class T, class Tbins, class... Targs, class Tfunctor>
        inline void compute_histogram_gpu_team_fetching(
            const sham::DeviceScheduler_ptr &dev_sched,
            const sham::DeviceBuffer<Tbins> &bin_edge_inf,
            const sham::DeviceBuffer<Tbins> &bin_edge_sup,
            size_t nbins,
            size_t element_count,
            Tfunctor &&functor,
            sham::DeviceBuffer<T> &result,
            const sham::DeviceBuffer<Targs> &...input_data) {

            sham::kernel_call_hndl(
                dev_sched->get_queue(),
                sham::MultiRef{bin_edge_inf, bin_edge_sup, input_data...},
                sham::MultiRef{result},
                nbins,
                [element_count, functor](
                    u32 nbins,
                    const Tbins *__restrict bin_edge_inf,
                    const Tbins *__restrict bin_edge_sup,
                    const Targs *__restrict... in_data,
                    T *__restrict result) {
                    return [=, in_data = std::tuple{in_data...}](sycl::handler &cgh) {
                        u32 group_size = 128;
                        u32 group_cnt = shambase::group_count(nbins, group_size);

                        // round up to the next multiple of 4
                        group_cnt = (group_cnt + 3) / 4 * 4;
                        u32 corrected_len = group_cnt * group_size;

                        auto locals
                            = sycl::local_accessor<std::tuple<Targs...>, 1>(group_size, cgh);

                        cgh.parallel_for(
                            sycl::nd_range<1>{corrected_len, group_size},
                            [=](sycl::nd_item<1> item) {
                                u32 local_id = item.get_local_id(0);
                                u32 group_tile_id = item.get_group_linear_id();
                                u32 ibin = group_tile_id * group_size + local_id;

                                bool is_valid_point = (ibin < nbins);
                                Tbins edge_inf = is_valid_point ? bin_edge_inf[ibin] : Tbins{};
                                Tbins edge_sup = is_valid_point ? bin_edge_sup[ibin] : Tbins{};

                                T local_sum = 0;

                                for (size_t i = 0; i < element_count; i += group_size) {

                                    item.barrier(sycl::access::fence_space::local_space);

                                    if (i + local_id < element_count) {
                                        std::apply(
                                            [&](auto &...in_data) {
                                                locals[local_id]
                                                    = std::tuple{in_data[i + local_id]...};
                                            },
                                            in_data);
                                    }

                                    item.barrier(sycl::access::fence_space::local_space);

                                    if (is_valid_point) {
                                        for (size_t lane = 0; lane < group_size; lane++) {
                                            if (i + lane >= element_count) {
                                                continue;
                                            }
                                            bool has_value = false;
                                            T tmp = std::apply(
                                                [&](auto &...local_accs) {
                                                    return functor(
                                                        edge_inf,
                                                        edge_sup,
                                                        local_accs...,
                                                        has_value);
                                                },
                                                locals[lane]);
                                            if (has_value) {
                                                local_sum += tmp;
                                            }
                                        }
                                    }

                                    item.barrier(sycl::access::fence_space::local_space);
                                }

                                if (is_valid_point) {
                                    result[ibin] = local_sum;
                                }
                            });
                    };
                });
        }
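
        // Design note: in the team-fetching variant each work-item owns one bin,
        // but the whole work-group cooperatively stages group_size elements into
        // local memory per iteration, so each element is read from global memory
        // once per group instead of once per bin as in the naive kernel.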

        template<class T, class Tbins, class... Targs, class Tfunctor>
        inline void compute_histogram_gpu_oversubscribe(
            const sham::DeviceScheduler_ptr &dev_sched,
            u32 group_size,
            const sham::DeviceBuffer<Tbins> &bin_edge_inf,
            const sham::DeviceBuffer<Tbins> &bin_edge_sup,
            size_t nbins,
            size_t element_count,
            Tfunctor &&functor,
            sham::DeviceBuffer<T> &result,
            const sham::DeviceBuffer<Targs> &...input_data) {

            sham::kernel_call_hndl(
                dev_sched->get_queue(),
                sham::MultiRef{bin_edge_inf, bin_edge_sup, input_data...},
                sham::MultiRef{result},
                nbins * group_size,
                [element_count, functor, group_size, nbins](
                    u32 nbins_oversubscribed,
                    const Tbins *__restrict bin_edge_inf,
                    const Tbins *__restrict bin_edge_sup,
                    const Targs *__restrict... in_data,
                    T *__restrict result) {
                    return [=, in_data = std::tuple{in_data...}](sycl::handler &cgh) {
                        u32 group_cnt = shambase::group_count(nbins_oversubscribed, group_size);

                        // round up to the next multiple of 4
                        group_cnt = (group_cnt + 3) / 4 * 4;

                        u32 corrected_len = group_cnt * group_size;

                        cgh.parallel_for(
                            sycl::nd_range<1>{corrected_len, group_size},
                            [=](sycl::nd_item<1> item) {
                                u32 local_id = item.get_local_id(0);
                                u32 ibin = item.get_group_linear_id();

                                bool is_valid_point = (ibin < nbins);
                                Tbins edge_inf = is_valid_point ? bin_edge_inf[ibin] : Tbins{};
                                Tbins edge_sup = is_valid_point ? bin_edge_sup[ibin] : Tbins{};

                                // for each thread this will be the sum of all the
                                // func(in_data[group_size*i + local_id]) terms
                                T local_sum = 0;

                                for (size_t i = 0; i < element_count; i += group_size) {

                                    if (i + local_id < element_count) {

                                        bool has_value = false;

                                        // coalesced read of the data, then
                                        // compute the value to accumulate
                                        T tmp = std::apply(
                                            [&](auto &...in_data) {
                                                return functor(
                                                    edge_inf,
                                                    edge_sup,
                                                    in_data[i + local_id]...,
                                                    has_value);
                                            },
                                            in_data);

                                        if (has_value) {
                                            // add it to the local sum of this thread
                                            local_sum += tmp;
                                        }
                                    }
                                }

                                // the partial sums are scattered across the threads of
                                // the group; reduce them into the final bin value
                                auto group_sum = sycl::reduce_over_group(
                                    item.get_group(), local_sum, sycl::plus<T>{});

                                if (is_valid_point && local_id == 0) {
                                    result[ibin] = group_sum;
                                }
                            });
                    };
                });
        }
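
        // Design note: the oversubscribed variant inverts the mapping of the
        // team-fetching kernel: one work-group per bin, whose group_size
        // work-items stride over the elements (coalesced reads); the per-thread
        // partial sums are then combined with a single sycl::reduce_over_group.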

    } // namespace impl

    template<class T, class Tbins, class... Targs, class Tfunctor>
    inline sham::DeviceBuffer<T> compute_histogram(
        const sham::DeviceScheduler_ptr &dev_sched,
        const sham::DeviceBuffer<Tbins> &bin_edge_inf,
        const sham::DeviceBuffer<Tbins> &bin_edge_sup,
        size_t element_count,
        Tfunctor &&functor,
        const sham::DeviceBuffer<Targs> &...input_data) {

        using namespace impl;

        size_t nbins = bin_edge_inf.get_size();

        if (nbins != bin_edge_sup.get_size()) {
            shambase::throw_with_loc( // call site restored; collapsed in this listing
                "bin_edge_inf and bin_edge_sup must have the same size");
        }

        sham::DeviceBuffer<T> result(nbins, dev_sched);

        switch (compute_histogram_impl_control.get_impl(dev_sched)) {
        case histo_impl::reference:
            compute_histogram_reference(
                bin_edge_inf,
                bin_edge_sup,
                nbins,
                element_count,
                std::forward<Tfunctor>(functor),
                result,
                input_data...);
            break;
        case histo_impl::naive_gpu:
            compute_histogram_naive_gpu(
                dev_sched,
                bin_edge_inf,
                bin_edge_sup,
                nbins,
                element_count,
                std::forward<Tfunctor>(functor),
                result,
                input_data...);
            break;
        case histo_impl::gpu_team_fetching:
            compute_histogram_gpu_team_fetching(
                dev_sched,
                bin_edge_inf,
                bin_edge_sup,
                nbins,
                element_count,
                std::forward<Tfunctor>(functor),
                result,
                input_data...);
            break;
        case histo_impl::gpu_oversubscribe:
            compute_histogram_gpu_oversubscribe(
                dev_sched,
                256,
                bin_edge_inf,
                bin_edge_sup,
                nbins,
                element_count,
                std::forward<Tfunctor>(functor),
                result,
                input_data...);
            break;
        default: shambase::throw_unimplemented("unknown implementation");
        }

        return result;
    }
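
    // Usage sketch for the generic entry point (illustrative; the buffers and
    // "dev_sched" are hypothetical placeholders):
    //
    //     // weighted histogram of positions x with weights w
    //     auto hist = compute_histogram<double>(
    //         dev_sched, edges_inf, edges_sup, x.get_size(),
    //         [](double einf, double esup, double xi, double wi, bool &has_value) {
    //             has_value = (xi >= einf && xi < esup);
    //             return has_value ? wi : 0.0;
    //         },
    //         x, w);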

    template<class T>
    inline sham::DeviceBuffer<T> compute_histogram_basic(
        const sham::DeviceScheduler_ptr &dev_sched,
        const sham::DeviceBuffer<T> &bin_edge_inf,
        const sham::DeviceBuffer<T> &bin_edge_sup,
        const sham::DeviceBuffer<T> &positions) {

        size_t element_count = positions.get_size();

        return compute_histogram<T>(
            dev_sched,
            bin_edge_inf,
            bin_edge_sup,
            element_count,
            [](const T &bin_edge_inf, const T &bin_edge_sup, const T &position, bool &has_value) {
                has_value = position >= bin_edge_inf && position < bin_edge_sup;
                return has_value ? 1 : 0;
            },
            positions);
    }
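
    // Usage sketch (illustrative): with bin_edge_inf = {0, 1, 2} and
    // bin_edge_sup = {1, 2, 3}, bin i counts the positions falling in [i, i+1):
    //
    //     auto counts = compute_histogram_basic<double>(
    //         dev_sched, bin_edge_inf, bin_edge_sup, positions);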

} // namespace shamalgs::primitives