Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
DigitBinner.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
19#include "shambase/integer.hpp"
21#include "shamalgs/memory.hpp"
22#include "shamalgs/numeric.hpp"
23#include "shambackends/sycl.hpp"
25#include <numeric>
26
28
29 template<class T, u32 digit_bit_len>
31 public:
32 static constexpr T bitlen_T = shambase::bitsizeof<T>;
33 static constexpr T digit_bit_places = bitlen_T / digit_bit_len;
34 static constexpr T digit_count = (1U << digit_bit_len);
35 static constexpr T value_count = digit_bit_places * digit_count;
36 static constexpr T digit_mask = digit_count - 1;
37
38 static_assert(
39 digit_bit_places * digit_bit_len == bitlen_T, "the conversion should be correct");
40
41 template<class Acc>
42 inline static void fetch_add_bin(Acc accessor, T digit_val, T digit_place) {
43 using atomic_ref_T = sycl::atomic_ref<
44 u32,
45 sycl::memory_order_relaxed,
46 sycl::memory_scope_work_group,
47 sycl::access::address_space::local_space>;
48
49 atomic_ref_T(accessor[digit_val + digit_place * digit_count]).fetch_add(1U);
50 }
51
52 inline static T get_digit_value(T value, T digit_place) {
53 return digit_mask & (value >> (digit_place * digit_bit_len));
54 }
55
56 template<class Acc>
57 inline static void add_bin_key(Acc accessor, T value_to_bin) {
58
59#pragma unroll
60 for (T digit_place = 0; digit_place < digit_bit_places; digit_place++) {
61 T shifted = get_digit_value(value_to_bin, digit_place);
62
63 fetch_add_bin(accessor, shifted, digit_place);
64 }
65 }
66
67 template<u32 group_size, class Tkey>
68 inline static sycl::buffer<u32> make_digit_histogram(
69 sycl::queue &q, sycl::buffer<Tkey> &buf_key, u32 len) {
70
71 u32 group_cnt = shambase::group_count(len, group_size);
72
73 group_cnt = group_cnt + (group_cnt % 4);
74 u32 corrected_len = group_cnt * group_size;
75
76 sycl::buffer<u32> digit_histogram(value_count);
77
78 shamalgs::memory::buf_fill_discard(q, digit_histogram, 0U);
79
80 // logger::raw_ln("digit binning");
81 // memory::print_buf(digit_histogram, value_count, digit_count, "{:4} ");
82
83 q.submit([&, len](sycl::handler &cgh) {
84 sycl::accessor keys{buf_key, cgh, sycl::read_only};
85 sycl::accessor histogram{digit_histogram, cgh, sycl::read_write};
86
87 sycl::local_accessor<u32, 1> local_histogram{value_count, cgh};
88
89 cgh.parallel_for(
90 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
91 u32 local_id = id.get_local_id(0);
92 u32 group_tile_id = id.get_group_linear_id();
93 u32 global_id = group_tile_id * group_size + local_id;
94
95 if (local_id == 0) {
96 for (u32 idx = 0; idx < value_count; idx++) {
97 local_histogram[idx] = 0;
98 }
99 }
100 id.barrier(sycl::access::fence_space::local_space);
101
102 // load from global buffer
103 if (global_id < len) {
104 add_bin_key(local_histogram, keys[global_id]);
105 }
106
107 id.barrier(sycl::access::fence_space::local_space);
108
109 for (u32 i = local_id; i < value_count; i += group_size) {
110 u32 dcount = local_histogram[i];
111
112 if (dcount != 0) {
113
114 using atomic_ref_t = sycl::atomic_ref<
115 u32,
116 sycl::memory_order_relaxed,
117 sycl::memory_scope_device,
118 sycl::access::address_space::global_space>;
119
120 atomic_ref_t(histogram[i]).fetch_add(dcount);
121 }
122 }
123 });
124 });
125
126 return digit_histogram;
127 }
128 };
129
130} // namespace shamalgs::algorithm::details
std::uint32_t u32
32 bit unsigned integer
namespace to store algorithms implemented by shamalgs
void buf_fill_discard(sycl::queue &q, sycl::buffer< T > &buf, T value)
Fill a buffer with a given value (sycl::no_init mode)
Definition memory.hpp:159
constexpr u32 group_count(u32 len, u32 group_size)
Calculates the number of groups based on the length and group size.
Definition integer.hpp:125
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
main include file for memory algorithms
Traits for C++ types.