33 static constexpr T digit_bit_places = bitlen_T / digit_bit_len;
34 static constexpr T digit_count = (1U << digit_bit_len);
35 static constexpr T value_count = digit_bit_places * digit_count;
36 static constexpr T digit_mask = digit_count - 1;
39 digit_bit_places * digit_bit_len == bitlen_T,
"the conversion should be correct");
// Atomically increments the histogram bin for digit value `digit_val` at
// digit position `digit_place`. Bins are laid out as one bank of
// `digit_count` counters per digit place, so the flat index is
// `digit_val + digit_place * digit_count`.
// The atomic_ref is bound to the local address space with work-group scope and
// relaxed ordering, so `accessor` is expected to refer to work-group local
// memory; visibility to other work-items relies on the caller's barriers.
// NOTE(review): the `template<...>` header declaring Acc and the first
// sycl::atomic_ref template argument (the value type) do not appear in this
// view; the value type must match the accessor's element type (u32 for the
// local histogram in make_digit_histogram) — confirm.
42 inline static void fetch_add_bin(Acc accessor, T digit_val, T digit_place) {
43 using atomic_ref_T = sycl::atomic_ref<
45 sycl::memory_order_relaxed,
46 sycl::memory_scope_work_group,
47 sycl::access::address_space::local_space>;
// Relaxed increment by one; the returned previous value is not used.
49 atomic_ref_T(accessor[digit_val + digit_place * digit_count]).fetch_add(1U);
52 inline static T get_digit_value(T value, T digit_place) {
53 return digit_mask & (value >> (digit_place * digit_bit_len));
// Adds one key to a digit histogram: for every digit place of `value_to_bin`,
// extracts that digit with get_digit_value and increments the matching bin
// with fetch_add_bin. Each key therefore contributes digit_bit_places
// increments in total.
// NOTE(review): the template header declaring Acc is not visible in this view.
57 inline static void add_bin_key(Acc accessor, T value_to_bin) {
// Walk digit places from least significant (0) to most significant.
60 for (T digit_place = 0; digit_place < digit_bit_places; digit_place++) {
// Despite its name, `shifted` is the already-masked digit value in
// [0, digit_count), not just the shifted key.
61 T shifted = get_digit_value(value_to_bin, digit_place);
63 fetch_add_bin(accessor, shifted, digit_place);
// Builds a per-digit-place histogram of the first `len` keys of `buf_key`.
// The returned buffer holds value_count u32 counters; bin
// `d + p * digit_count` counts the keys whose digit at place p equals d
// (the layout written by fetch_add_bin). Each work-group first accumulates
// into a work-group-local histogram, then adds it into the returned buffer
// with device-scope atomics.
// NOTE(review): the lines that declare/compute the initial group_cnt, any
// initialization of digit_histogram, and the parallel_for call that takes the
// nd_range below are not visible here. The atomic adds into `histogram`
// assume digit_histogram starts zeroed — confirm.
67 template<u32 group_size,
class Tkey>
68 inline static sycl::buffer<u32> make_digit_histogram(
69 sycl::queue &q, sycl::buffer<Tkey> &buf_key,
u32 len) {
// NOTE(review): `group_cnt + (group_cnt % 4)` does not round up to a multiple
// of 4 (e.g. 5 -> 6, 7 -> 10); if that is the intent, `(group_cnt + 3) / 4 * 4`
// does it. Extra padding work-items are harmless thanks to the
// `global_id < len` guard in the kernel.
73 group_cnt = group_cnt + (group_cnt % 4);
// Global range: a whole number of work-groups (presumably enough to cover
// `len` keys, depending on how group_cnt was derived — not visible here).
74 u32 corrected_len = group_cnt * group_size;
// Output: one counter per (digit place, digit value) pair.
76 sycl::buffer<u32> digit_histogram(value_count);
83 q.submit([&, len](sycl::handler &cgh) {
84 sycl::accessor keys{buf_key, cgh, sycl::read_only};
85 sycl::accessor histogram{digit_histogram, cgh, sycl::read_write};
// Scratch histogram shared by the work-items of one work-group, same layout
// as the output.
87 sycl::local_accessor<u32, 1> local_histogram{value_count, cgh};
90 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
91 u32 local_id =
id.get_local_id(0);
92 u32 group_tile_id =
id.get_group_linear_id();
93 u32 global_id = group_tile_id * group_size + local_id;
// Clear the local histogram. Every work-item writes every entry here.
// NOTE(review): a strided loop (start at local_id, step group_size), like the
// merge loop below, would clear each entry exactly once.
96 for (
u32 idx = 0; idx < value_count; idx++) {
97 local_histogram[idx] = 0;
// Make the cleared histogram visible to the whole work-group before binning.
100 id.barrier(sycl::access::fence_space::local_space);
// Skip padding work-items beyond the last key.
103 if (global_id < len) {
104 add_bin_key(local_histogram, keys[global_id]);
// Wait until every work-item has binned its key before merging.
107 id.barrier(sycl::access::fence_space::local_space);
// Merge: work-item local_id handles bins local_id, local_id + group_size, ...
109 for (
u32 i = local_id; i < value_count; i += group_size) {
110 u32 dcount = local_histogram[i];
// Every work-group updates the same output bins, hence device-scope atomics
// on global memory. NOTE(review): the atomic_ref value-type argument does not
// appear in this view; it must match the u32 element type of `histogram`.
114 using atomic_ref_t = sycl::atomic_ref<
116 sycl::memory_order_relaxed,
117 sycl::memory_scope_device,
118 sycl::access::address_space::global_space>;
120 atomic_ref_t(histogram[i]).fetch_add(dcount);
126 return digit_histogram;