Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
radixSortOnesweep.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
18
19#include "shambase/integer.hpp"
21#include "DigitBinner.hpp"
23#include "shamalgs/memory.hpp"
24#include "shamalgs/numeric.hpp"
25#include "shambackends/sycl.hpp"
27#include <numeric>
28
30
31 /*
32 tile histogram :
33
34 element a :
35
36 a = 2 + 1x4^3   (base 4 digits of a, from place 0 to place 3 : 2, 0, 0, 1)
37
38 | digit | digit places  |
39 |       | 0 | 1 | 2 | 3 |
40 -------------------------
41 |   0   | 0 | 1 | 1 | 0 |
42 |   1   | 0 | 0 | 0 | 1 |
43 |   2   | 1 | 0 | 0 | 0 |
44 |   3   | 0 | 0 | 0 | 0 |
45
46 sum array on the table
47
48 */
49
50 template<class Tkey, class Tval, u32 group_size, u32 digit_len>
52
/**
 * @brief Multi-pass key/value scatter driver; per its name, intended to sort
 *        (key, value) pairs by key with a "onesweep" radix sort.
 *
 * One kernel is submitted per group of `digit_len` bits of `Tkey`. Each pass
 * reads keys/values from one pair of buffers and scatters them into the other
 * pair, alternating between the caller's buffers and two temporaries allocated
 * here (selection is done on the parity of the pass index `step`).
 *
 * The destination index of an element is the sum of three terms:
 *   - the value returned by an atomic increment of a per-work-group digit counter,
 *   - the value produced for that digit by the decoupled look-back scan helper,
 *   - the host-side exclusive scan of the global digit histogram for this pass.
 *
 * @tparam Tkey       key type (also used as the type of the bit-position loop counter)
 * @tparam Tval       value type carried along with each key
 * @tparam group_size work-group size used for the nd_range and forwarded to the helpers
 * @tparam digit_len  number of key bits consumed per pass (template argument of DigitBinner)
 *
 * @param q          queue on which the histogram and scatter kernels are submitted
 * @param buf_key    keys; input of even passes, output of odd passes
 * @param buf_values values; input of even passes, output of odd passes
 * @param len        number of elements processed (work-items with id >= len are masked out)
 *
 * NOTE(review): the output of the last pass is in buf_key/buf_values only when the
 * number of passes is even; when it is odd the last pass writes into the local
 * temporaries and no copy back is visible in this function - confirm callers only
 * use (Tkey, digit_len) combinations giving an even pass count.
 *
 * NOTE(review): len == 0 is not special-cased (zero-sized temporaries, group count
 * of zero) - confirm callers never pass an empty range.
 */
53 template<class Tkey, class Tval, u32 group_size, u32 digit_len>
54 void sort_by_key_radix_onesweep(
55 sycl::queue &q, sycl::buffer<Tkey> &buf_key, sycl::buffer<Tval> &buf_values, u32 len) {
56
    // Scratch buffers of `len` elements used as the second half of the ping-pong.
57 sycl::buffer<Tkey> tmp_buf_key(len);
58 sycl::buffer<Tval> tmp_buf_values(len);
59
    // Input keys of a pass: caller buffer on even steps, temporary on odd steps.
60 auto get_in_keys = [&](u32 step) -> sycl::buffer<Tkey> & {
61 if (step % 2 == 0) {
62 return buf_key;
63 } else {
64 return tmp_buf_key;
65 }
66 };
67
    // Output keys of a pass: always the buffer that is not the input of that step.
68 auto get_out_keys = [&](u32 step) -> sycl::buffer<Tkey> & {
69 if (step % 2 == 0) {
70 return tmp_buf_key;
71 } else {
72 return buf_key;
73 }
74 };
75
    // Same parity rule as get_in_keys, for the values.
76 auto get_in_vals = [&](u32 step) -> sycl::buffer<Tval> & {
77 if (step % 2 == 0) {
78 return buf_values;
79 } else {
80 return tmp_buf_values;
81 }
82 };
83
    // Same parity rule as get_out_keys, for the values.
84 auto get_out_vals = [&](u32 step) -> sycl::buffer<Tval> & {
85 if (step % 2 == 0) {
86 return tmp_buf_values;
87 } else {
88 return buf_values;
89 }
90 };
91
    // Number of work-groups for `len` elements with groups of `group_size`.
92 u32 group_cnt = shambase::group_count(len, group_size);
93
94 // group_cnt = group_cnt + (group_cnt % 4);
    // Global range of the kernel: a whole number of work-groups, so it can exceed
    // `len`; the extra work-items are masked with `global_id < len` in the kernel.
95 u32 corrected_len = group_cnt * group_size;
96
97 // memory::print_buf(buf_key, len, 16, "{:4} ");
98
99 using Binner = DigitBinner<Tkey, digit_len>;
100
    // Digit histogram of the input keys, built by the DigitBinner helper.
    // Indexed below as [digit_place * Binner::digit_count + digit].
101 sycl::buffer<u32> digit_histogram
102 = Binner::template make_digit_histogram<group_size>(q, buf_key, len);
103
104 // logger::raw_ln("digit histogram");
105 // memory::print_buf(digit_histogram, Binner::value_count, Binner::digit_count, "{:4} ");
106
    // Scoped so that the host accessor is destroyed before the kernels below are submitted.
107 {
108
109 sycl::host_accessor acc{digit_histogram, sycl::read_write};
110
111 auto ptr = &(acc[0]);
112
    // In-place exclusive scan (initial value 0) of each row of Binner::digit_count
    // entries, one row per digit place: each entry becomes the sum of the counts
    // of the smaller digits of the same digit place.
113 for (u32 digit_place = 0; digit_place < Binner::digit_bit_places; digit_place++) {
114 u32 offset_ptr = Binner::digit_count * digit_place;
115 std::exclusive_scan(
116 ptr + offset_ptr, ptr + offset_ptr + Binner::digit_count, ptr + offset_ptr, 0);
117 }
118 }
119
120 // logger::raw_ln("digit histogram");
121 // memory::print_buf(digit_histogram, Binner::value_count, Binner::digit_count, "{:4} ");
122
123 using namespace shamalgs::numeric::details;
124
125 using DecoupledLookBack
126 = ScanDecoupledLoockBack<u32, group_size, Standard, ScanTile30bitint>;
127
    // `step` is the pass index; `cur_digit_place` is the bit position of the current
    // digit and only drives the loop (its other uses are in commented-out debug code).
128 u32 step = 0;
129 for (Tkey cur_digit_place = 0; cur_digit_place < shambase::bitsizeof<Tkey>;
130 cur_digit_place += digit_len) {
131
    // A new look-back scan object is built for every pass, sized with the number of
    // work-groups and the number of digit values.
132 DecoupledLookBack dlookbackscan(q, group_cnt, Binner::digit_count);
133
    // Generator of the dynamic group/global ids used in the kernel instead of the
    // native ones (the native-id variant is kept commented out in the kernel).
    // NOTE(review): presumably this gives tile ids in scheduling order, as needed by
    // the look-back scan - confirm against DynamicIdGenerator.
134 atomic::DynamicIdGenerator<i32, group_size> id_gen(q);
135
136 q.submit([&, len, cur_digit_place, step](sycl::handler &cgh) {
    // Inputs and outputs of this pass, chosen from the parity of `step`.
137 sycl::accessor keys{get_in_keys(step), cgh, sycl::read_only};
138 sycl::accessor vals{get_in_vals(step), cgh, sycl::read_only};
139
140 sycl::accessor new_keys{get_out_keys(step), cgh, sycl::write_only, sycl::no_init};
141 sycl::accessor new_vals{get_out_vals(step), cgh, sycl::write_only, sycl::no_init};
142
    // Scanned global histogram computed on the host above.
143 sycl::accessor value_write_offsets{digit_histogram, cgh, sycl::read_only};
144
    // Per-work-group arrays with one entry per digit value: the counts of the tile,
    // and the values written back by the look-back scan.
145 sycl::local_accessor<u32, 1> local_digit_counts{Binner::digit_count, cgh};
146 sycl::local_accessor<u32, 1> scanned_digit_counts{Binner::digit_count, cgh};
147
148 // sycl::stream dump (4096,1024,cgh);
149 auto dyn_id = id_gen.get_access(cgh);
150
151 auto scanop = dlookbackscan.get_access(cgh);
152
    // Relaxed, work-group scoped atomic reference on a u32 in local memory.
153 using at_ref_loc_count = sycl::atomic_ref<
154 u32,
155 sycl::memory_order_relaxed,
156 sycl::memory_scope_work_group,
157 sycl::access::address_space::local_space>;
158
    // Start of the row of the scanned histogram used by this pass.
159 u32 histogram_ptr_offset = step * Binner::digit_count;
160
    // NOTE(review): the listing jumps from line 160 to line 162; the line that opens
    // the kernel launch call taking the two arguments below is absent from this listing.
162 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
163 atomic::DynamicId<i32> group_id = dyn_id.compute_id(id);
164
165 u32 local_id = id.get_local_id(0);
166 u32 group_tile_id = group_id.dyn_group_id;
167 u32 global_id = group_id.dyn_global_id;
168
169 // u32 group_tile_id = id.get_group_linear_id();
170 // u32 global_id = group_tile_id * group_size + local_id;
171
    // The first work-item of the group zeroes the per-digit counters; the barrier
    // makes the zeroed counters visible to the group before they are incremented.
172 if (local_id == 0) {
173 for (u32 digit_ptr = 0; digit_ptr < Binner::digit_count; digit_ptr++) {
174 local_digit_counts[digit_ptr] = 0;
175 }
176 }
177 id.barrier(sycl::access::fence_space::local_space);
178
    // Work-items past `len` use key 0 and add 0 below, so they do not change any counter.
179 bool is_valid_key = (global_id < len);
180
181 Tkey cur_key = (is_valid_key) ? keys[global_id] : 0;
182
    // Digit of the key handled by pass number `step`.
183 Tkey digit_value = Binner::get_digit_value(cur_key, step);
184
185 // if(group_tile_id == 0){
186 // dump << local_digit_counts[0] << " " << local_digit_counts[1] <<
187 // "\n";
188 // }
189
    // fetch_add returns the counter value before the increment: the number of
    // work-items of this group with the same digit that incremented it earlier.
    // NOTE(review): this rank follows the execution order of the atomics, not the
    // work-item order - confirm this is acceptable for a multi-pass digit sort.
190 u32 curr_loc_offset = at_ref_loc_count(local_digit_counts[digit_value])
191 .fetch_add((is_valid_key) ? 1U : 0);
192
193 // if(group_tile_id == 0){
194 // dump << cur_key << " " <<digit_value << " " << curr_loc_offset <<
195 // "\n";
196 // }
197 //
198 //
    // All increments of the group are done before the counters are read by the scan.
199 id.barrier(sycl::access::fence_space::local_space);
200 // if(group_tile_id == 0){
201 // dump << local_digit_counts[0] << " " << local_digit_counts[1] <<
202 // "\n";
203 // }
204
205 // generate scanned tile value for each digits
    // One look-back scan per digit value: the first lambda supplies the count of this
    // tile, the second stores the value handed back by the scan helper.
    // NOTE(review): presumably `accum` is the total count of that digit over the
    // preceding tiles - confirm against ScanDecoupledLoockBack.
206 for (u32 digit_ptr = 0; digit_ptr < Binner::digit_count; digit_ptr++) {
207
208 scanop.decoupled_lookback_scan(
209 id,
210 local_id,
211 group_tile_id,
212 [=]() {
213 return local_digit_counts[digit_ptr];
214 },
215 [=](u32 accum) {
216 scanned_digit_counts[digit_ptr] = accum;
217 },
218 digit_ptr);
219 }
220
221 // if(local_id == 0){
222 // dump << "-- gid" << global_id << "\n";
223 // for(u32 digit_ptr = 0; digit_ptr < Binner::digit_count; digit_ptr
224 // ++){
225 // dump << local_digit_counts[digit_ptr] << " "
226 // <<scanned_digit_counts[digit_ptr] << "\n";
227 // }
228 // }
229
230 // load from global buffer
231 if (global_id < len) {
232
233 // logger::raw_ln(cur_key,digit_value,curr_loc_offset, step);
234
    // Offset of this digit in the scanned global histogram of the current pass.
235 u32 value_write_offset_global
236 = value_write_offsets[(digit_value) + histogram_ptr_offset];
237
    // Destination index = rank inside the tile + value from the look-back scan
    // + global offset of the digit.
238 u32 write_offset = curr_loc_offset + scanned_digit_counts[digit_value]
239 + value_write_offset_global;
240
241 // if(local_id == 0){
242 // dump << "-- gid" << global_id << "\n";
243 // dump << "k="<<cur_key << "\n";
244 // dump << "d="<<digit_value << "\n";
245 // dump << "delta="<<curr_loc_offset << "\n";
246 // dump << "gdelta="<<value_write_offset_global << "\n";
247 // dump << "sdelta="<<scanned_digit_counts[digit_value] << "\n";
248 // dump << "wdelta="<<write_offset << "\n";
249 // }
250
    // Scatter the key and its value to the output buffers of this pass.
251 new_keys[write_offset]
252 = keys[global_id]; // can be loaded initially and stored only here
253 // rather than reload
254 new_vals[write_offset] = vals[global_id];
255 }
256 });
257 });
258
259 // q.wait();
260
261 // logger::raw_ln("digit histogram place : ", cur_digit_place);
262 // memory::print_buf(get_out_keys(step), len, 16, "{:4} ");
263
264 // return;
265
    // Next pass: swaps the roles of the two buffer pairs and selects the next digit.
266 step++;
267 }
268 }
269
270} // namespace shamalgs::algorithm::details
std::uint32_t u32
32 bit unsigned integer
namespace to store algorithms implemented by shamalgs
constexpr u32 group_count(u32 len, u32 group_size)
Calculates the number of groups based on the length and group size.
Definition integer.hpp:125
main include file for memory algorithms
Traits for C++ types.