Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
radixSortOnesweep.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
19#include "shambase/integer.hpp"
21#include "DigitBinner.hpp"
23#include "shamalgs/memory.hpp"
24#include "shamalgs/numeric.hpp"
25#include "shambackends/sycl.hpp"
27#include <numeric>
28
30
31 /*
32 tile histogram :
33
34 element a :
35
36 a = 2 + 1x4^3
37
38 |digit | digit places |
39 | | 0 | 1 | 2 | 3 |
40 ------------------------
41 | 0 | 0 | 1 | 1 | 0 |
42 | 1 | 0 | 0 | 0 | 1 |
43 | 2 | 1 | 0 | 0 | 0 |
44 | 3 | 0 | 0 | 0 | 0 |
45
46 the digit histogram is the sum of this table over all the elements of the array
47
48 */
49
// NOTE(review): the declaration introduced by this template header is not visible in this
// listing. Presumably it is the `SortByKeyRadixOnesweep` kernel-name class that
// `sort_by_key_radix_onesweep` passes to `parallel_for` with the same four template
// arguments - confirm against the original file.
50 template<class Tkey, class Tval, u32 group_size, u32 digit_len>
52
/**
 * @brief Radix sort of `buf_key`, moving the entries of `buf_values` along with their keys.
 *
 * What the function does, in order:
 *  - allocates two temporary buffers of `len` elements (keys and values);
 *  - builds a digit histogram with `DigitBinner::make_digit_histogram` and, on the host,
 *    turns every digit-place slice of it into an exclusive prefix sum;
 *  - submits one kernel per digit place. In that kernel each work-item computes its
 *    destination index as
 *        rank obtained from a work-group local atomic counter of its digit
 *      + value delivered by `decoupled_lookback_scan` for that digit
 *        (presumably the count of that digit in the preceding groups - confirm)
 *      + scanned histogram entry of that digit for the current step,
 *    then copies its key and value to that index of the output buffers.
 *
 * The passes ping-pong between the caller's buffers and the temporaries: even steps read
 * `buf_key`/`buf_values` and write the temporaries, odd steps do the opposite.
 *
 * NOTE(review): because of that parity scheme the last pass writes into the caller's buffers
 * only when the number of passes is even; with an odd number of passes the final result would
 * stay in the temporaries and nothing visible here copies it back - confirm the supported
 * `Tkey`/`digit_len` combinations.
 *
 * @tparam Tkey       key type; its bit size bounds the digit-place loop
 * @tparam Tval       value type carried along with each key
 * @tparam group_size work-group size of the launched kernels
 * @tparam digit_len  increment of the digit-place loop, also forwarded to `DigitBinner`
 *                    (presumably the number of bits per digit - confirm)
 * @param q          queue on which the kernels are submitted
 * @param buf_key    keys to sort (input of even steps, output of odd steps)
 * @param buf_values values associated with the keys, reordered like the keys
 * @param len        number of elements processed; work-items with a global id >= len
 *                   neither count nor write anything
 */
53 template<class Tkey, class Tval, u32 group_size, u32 digit_len>
54 void sort_by_key_radix_onesweep(
55 sycl::queue &q, sycl::buffer<Tkey> &buf_key, sycl::buffer<Tval> &buf_values, u32 len) {
56
// Scratch buffers used as the second half of the ping-pong scheme.
57 sycl::buffer<Tkey> tmp_buf_key(len);
58 sycl::buffer<Tval> tmp_buf_values(len);
59
// Key buffer read during `step`: caller's buffer on even steps, scratch on odd ones.
60 auto get_in_keys = [&](u32 step) -> sycl::buffer<Tkey> & {
61 if (step % 2 == 0) {
62 return buf_key;
63 } else {
64 return tmp_buf_key;
65 }
66 };
67
// Key buffer written during `step`: always the one `get_in_keys` does not return.
68 auto get_out_keys = [&](u32 step) -> sycl::buffer<Tkey> & {
69 if (step % 2 == 0) {
70 return tmp_buf_key;
71 } else {
72 return buf_key;
73 }
74 };
75
// Value buffer read during `step` (same parity rule as the keys).
76 auto get_in_vals = [&](u32 step) -> sycl::buffer<Tval> & {
77 if (step % 2 == 0) {
78 return buf_values;
79 } else {
80 return tmp_buf_values;
81 }
82 };
83
// Value buffer written during `step` (same parity rule as the keys).
84 auto get_out_vals = [&](u32 step) -> sycl::buffer<Tval> & {
85 if (step % 2 == 0) {
86 return tmp_buf_values;
87 } else {
88 return buf_values;
89 }
90 };
91
92 u32 group_cnt = shambase::group_count(len, group_size);
93
94 // group_cnt = group_cnt + (group_cnt % 4);
// Global range of the kernels: a whole number of work-groups. The work-items in excess
// of `len` are neutralised inside the kernel through `is_valid_key`.
95 u32 corrected_len = group_cnt * group_size;
96
97 // memory::print_buf(buf_key, len, 16, "{:4} ");
98
99 using Binner = DigitBinner<Tkey, digit_len>;
100
// Histogram computed once from the unsorted input keys. It is indexed below as
// `digit + Binner::digit_count * digit_place`.
101 sycl::buffer<u32> digit_histogram
102 = Binner::template make_digit_histogram<group_size>(q, buf_key, len);
103
104 // logger::raw_ln("digit histogram");
105 // memory::print_buf(digit_histogram, Binner::value_count, Binner::digit_count, "{:4} ");
106
107 {
108
// Host-side pass: the host accessor lives only for this scope.
109 sycl::host_accessor acc{digit_histogram, sycl::read_write};
110
111 auto ptr = &(acc[0]);
112
// In-place exclusive scan of each digit-place slice, so that entry `d` becomes the
// number of keys whose digit at that place is smaller than `d`.
// NOTE(review): the init value `0` is an `int`, so std::exclusive_scan accumulates in
// `int` while the data is `u32` - consider an unsigned literal.
113 for (u32 digit_place = 0; digit_place < Binner::digit_bit_places; digit_place++) {
114 u32 offset_ptr = Binner::digit_count * digit_place;
115 std::exclusive_scan(
116 ptr + offset_ptr, ptr + offset_ptr + Binner::digit_count, ptr + offset_ptr, 0);
117 }
118 }
119
120 // logger::raw_ln("digit histogram");
121 // memory::print_buf(digit_histogram, Binner::value_count, Binner::digit_count, "{:4} ");
122
123 using namespace shamalgs::numeric::details;
124
// NOTE(review): the right-hand side of this alias is missing from this listing.
125 using DecoupledLookBack
127
// One iteration per digit place; `step` is the iteration index and drives both the
// buffer parity and the histogram slice used by the kernel.
128 u32 step = 0;
129 for (Tkey cur_digit_place = 0; cur_digit_place < shambase::bitsizeof<Tkey>;
130 cur_digit_place += digit_len) {
131
// A fresh scan object is constructed for every pass.
132 DecoupledLookBack dlookbackscan(q, group_cnt, Binner::digit_count);
133
// NOTE(review): `id_gen`, used below, is not declared anywhere in this listing; its
// declaration presumably sits on the line missing here - confirm.
135
136 q.submit([&, len, cur_digit_place, step](sycl::handler &cgh) {
137 sycl::accessor keys{get_in_keys(step), cgh, sycl::read_only};
138 sycl::accessor vals{get_in_vals(step), cgh, sycl::read_only};
139
140 sycl::accessor new_keys{get_out_keys(step), cgh, sycl::write_only, sycl::no_init};
141 sycl::accessor new_vals{get_out_vals(step), cgh, sycl::write_only, sycl::no_init};
142
143 sycl::accessor value_write_offsets{digit_histogram, cgh, sycl::read_only};
144
// Work-group local storage, one slot per digit value: raw counts of the group and the
// value written back by the look-back scan.
145 sycl::local_accessor<u32, 1> local_digit_counts{Binner::digit_count, cgh};
146 sycl::local_accessor<u32, 1> scanned_digit_counts{Binner::digit_count, cgh};
147
148 // sycl::stream dump (4096,1024,cgh);
149 auto dyn_id = id_gen.get_access(cgh);
150
151 auto scanop = dlookbackscan.get_access(cgh);
152
// Relaxed, work-group scoped atomic view on a local-memory counter.
153 using at_ref_loc_count = sycl::atomic_ref<
154 u32,
155 sycl::memory_order_relaxed,
156 sycl::memory_scope_work_group,
157 sycl::access::address_space::local_space>;
158
// Start of the scanned histogram slice belonging to the current step.
159 u32 histogram_ptr_offset = step * Binner::digit_count;
160
161 cgh.parallel_for<SortByKeyRadixOnesweep<Tkey, Tval, group_size, digit_len>>(
162 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
// Group and global ids come from `dyn_id` rather than from `id` directly
// (the direct variant is kept commented out below).
163 atomic::DynamicId<i32> group_id = dyn_id.compute_id(id);
164
165 u32 local_id = id.get_local_id(0);
166 u32 group_tile_id = group_id.dyn_group_id;
167 u32 global_id = group_id.dyn_global_id;
168
169 // u32 group_tile_id = id.get_group_linear_id();
170 // u32 global_id = group_tile_id * group_size + local_id;
171
// The first work-item of the group resets every local counter; the barrier
// that follows makes the zeros visible before any increment.
172 if (local_id == 0) {
173 for (u32 digit_ptr = 0; digit_ptr < Binner::digit_count; digit_ptr++) {
174 local_digit_counts[digit_ptr] = 0;
175 }
176 }
177 id.barrier(sycl::access::fence_space::local_space);
178
179 bool is_valid_key = (global_id < len);
180
// Padding work-items use key 0 and add 0 to the counter, so they do not
// change any count.
181 Tkey cur_key = (is_valid_key) ? keys[global_id] : 0;
182
183 Tkey digit_value = Binner::get_digit_value(cur_key, step);
184
185 // if(group_tile_id == 0){
186 // dump << local_digit_counts[0] << " " << local_digit_counts[1] <<
187 // "\n";
188 // }
189
// fetch_add returns the counter value before the increment: this is the rank
// of the key among the keys of this group sharing the same digit.
// NOTE(review): the rank depends on the order in which the atomics execute,
// so the relative order of equal-digit keys inside a group is not tied to
// their input order; later radix passes normally rely on a stable order -
// confirm this is acceptable.
190 u32 curr_loc_offset = at_ref_loc_count(local_digit_counts[digit_value])
191 .fetch_add((is_valid_key) ? 1U : 0);
192
193 // if(group_tile_id == 0){
194 // dump << cur_key << " " <<digit_value << " " << curr_loc_offset <<
195 // "\n";
196 // }
197 //
198 //
// All increments of the group must be done before the counts are read below.
199 id.barrier(sycl::access::fence_space::local_space);
200 // if(group_tile_id == 0){
201 // dump << local_digit_counts[0] << " " << local_digit_counts[1] <<
202 // "\n";
203 // }
204
205 // generate scanned tile value for each digits
// One look-back scan per digit value: the first lambda provides the count of
// this group, the second one stores the value handed back by the scan.
206 for (u32 digit_ptr = 0; digit_ptr < Binner::digit_count; digit_ptr++) {
207
208 scanop.decoupled_lookback_scan(
209 id,
210 local_id,
211 group_tile_id,
212 [=]() {
213 return local_digit_counts[digit_ptr];
214 },
215 [=](u32 accum) {
216 scanned_digit_counts[digit_ptr] = accum;
217 },
218 digit_ptr);
219 }
220
221 // if(local_id == 0){
222 // dump << "-- gid" << global_id << "\n";
223 // for(u32 digit_ptr = 0; digit_ptr < Binner::digit_count; digit_ptr
224 // ++){
225 // dump << local_digit_counts[digit_ptr] << " "
226 // <<scanned_digit_counts[digit_ptr] << "\n";
227 // }
228 // }
229
// NOTE(review): no barrier is visible between the scan loop above and the reads
// of `scanned_digit_counts` below; presumably `decoupled_lookback_scan`
// synchronises the group itself - confirm.
230 // load from global buffer
231 if (global_id < len) {
232
233 // logger::raw_ln(cur_key,digit_value,curr_loc_offset, step);
234
// Scanned histogram entry of this digit for the current step.
235 u32 value_write_offset_global
236 = value_write_offsets[(digit_value) + histogram_ptr_offset];
237
// Destination index = rank in the group + scan result for the digit
// + scanned histogram entry.
238 u32 write_offset = curr_loc_offset + scanned_digit_counts[digit_value]
239 + value_write_offset_global;
240
241 // if(local_id == 0){
242 // dump << "-- gid" << global_id << "\n";
243 // dump << "k="<<cur_key << "\n";
244 // dump << "d="<<digit_value << "\n";
245 // dump << "delta="<<curr_loc_offset << "\n";
246 // dump << "gdelta="<<value_write_offset_global << "\n";
247 // dump << "sdelta="<<scanned_digit_counts[digit_value] << "\n";
248 // dump << "wdelta="<<write_offset << "\n";
249 // }
250
251 new_keys[write_offset]
252 = keys[global_id]; // can be loaded initially and stored only here
253 // rather than reload
254 new_vals[write_offset] = vals[global_id];
255 }
256 });
257 });
258
259 // q.wait();
260
261 // logger::raw_ln("digit histogram place : ", cur_digit_place);
262 // memory::print_buf(get_out_keys(step), len, 16, "{:4} ");
263
264 // return;
265
// Next pass: swaps the roles of the buffers and moves to the next histogram slice.
266 step++;
267 }
268 }
269
270} // namespace shamalgs::algorithm::details
std::uint32_t u32
32 bit unsigned integer
Sycl utility to dynamically generate group ids.
Object returned by DynamicIdGenerator containing information about the worker affected id.
namespace to store algorithms implemented by shamalgs
constexpr u32 group_count(u32 len, u32 group_size)
Calculates the number of groups based on the length and group size.
Definition integer.hpp:125
main include file for memory algorithms
Traits for C++ types.