Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
bitonicSort_legacy.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
18#include "shambase/integer.hpp"
20#include "shamcomm/logs.hpp"
21#include <stdexcept>
22
23// modified from http://www.bealto.com/gpu-sorting.html
24
25#define MAXORDER_SORT_KERNEL 16
26
// Compare-exchange building blocks for the bitonic sorting networks.
//
// Every macro below expects three names to be visible where it is expanded:
//   Tkey    - type of the keys being ordered
//   Tval    - type of the value carried along with each key
//   reverse - bool selecting the direction of every compare-exchange
//
// The macros expand to plain brace blocks rather than do { } while (0) because
// they are chained at their call sites without separating semicolons.
// All parameters are parenthesized in the expansions, and the helper locals use
// an `order_` prefix with a trailing underscore so they are unlikely to collide
// with the names passed in. Arguments are evaluated several times and must be
// free of side effects.

// Compare-exchange the key lvalues (a, b) together with their values (ida, idb).
// The pairs are exchanged when `reverse ^ (a < b)` is true, so afterwards:
//   reverse == false : a >= b (equal keys are left untouched)
//   reverse == true  : a <= b (equal keys still trade places)
#define ORDER(a, b, ida, idb)                                                                      \
    {                                                                                              \
        const bool order_swap_ = reverse ^ ((a) < (b));                                            \
        const Tkey order_ka_   = (a);                                                              \
        const Tkey order_kb_   = (b);                                                              \
        const Tval order_va_   = (ida);                                                            \
        const Tval order_vb_   = (idb);                                                            \
        (a)   = order_swap_ ? order_kb_ : order_ka_;                                               \
        (b)   = order_swap_ ? order_ka_ : order_kb_;                                               \
        (ida) = order_swap_ ? order_vb_ : order_va_;                                               \
        (idb) = order_swap_ ? order_va_ : order_vb_;                                               \
    }

// Same compare-exchange as ORDER, applied to slots `a` and `b` of the key array
// `x` and of the value array `idx`.
#define ORDERV(x, idx, a, b)                                                                       \
    {                                                                                              \
        const bool order_swap_ = reverse ^ ((x)[(a)] < (x)[(b)]);                                  \
        const Tkey order_ka_   = (x)[(a)];                                                         \
        const Tkey order_kb_   = (x)[(b)];                                                         \
        const Tval order_va_   = (idx)[(a)];                                                       \
        const Tval order_vb_   = (idx)[(b)];                                                       \
        (x)[(a)]   = order_swap_ ? order_kb_ : order_ka_;                                          \
        (x)[(b)]   = order_swap_ ? order_ka_ : order_kb_;                                          \
        (idx)[(a)] = order_swap_ ? order_vb_ : order_va_;                                          \
        (idx)[(b)] = order_swap_ ? order_va_ : order_vb_;                                          \
    }

// Bitonic merge of the 2 slots starting at index `a`.
#define B2V(x, idx, a) {ORDERV(x, idx, (a), (a) + 1)}

// Bitonic merge of the 4 slots starting at index `a`: one half-distance
// compare-exchange sweep, then a B2V merge of each half.
#define B4V(x, idx, a)                                                                             \
    {                                                                                              \
        for (int i4 = 0; i4 < 2; i4++) {                                                           \
            ORDERV(x, idx, (a) + i4, (a) + i4 + 2)                                                 \
        }                                                                                          \
        B2V(x, idx, (a)) B2V(x, idx, (a) + 2)                                                      \
    }

// Bitonic merge of the 8 slots starting at index `a`.
#define B8V(x, idx, a)                                                                             \
    {                                                                                              \
        for (int i8 = 0; i8 < 4; i8++) {                                                           \
            ORDERV(x, idx, (a) + i8, (a) + i8 + 4)                                                 \
        }                                                                                          \
        B4V(x, idx, (a)) B4V(x, idx, (a) + 4)                                                      \
    }

// Bitonic merge of the 16 slots starting at index `a`.
#define B16V(x, idx, a)                                                                            \
    {                                                                                              \
        for (int i16 = 0; i16 < 8; i16++) {                                                        \
            ORDERV(x, idx, (a) + i16, (a) + i16 + 8)                                               \
        }                                                                                          \
        B8V(x, idx, (a)) B8V(x, idx, (a) + 8)                                                      \
    }

// Bitonic merge of the 32 slots starting at index `a`.
#define B32V(x, idx, a)                                                                            \
    {                                                                                              \
        for (int i32 = 0; i32 < 16; i32++) {                                                       \
            ORDERV(x, idx, (a) + i32, (a) + i32 + 16)                                              \
        }                                                                                          \
        B16V(x, idx, (a)) B16V(x, idx, (a) + 16)                                                   \
    }
86
// Forward declarations of ten tag classes, one per combination of network
// size (B2, B4, B8, B16, B32) and key width suffix (morton32, morton64).
// They are only declared here; nothing in the visible listing defines or
// references them.
// NOTE(review): presumably intended as SYCL kernel names -- the kernels visible
// in this listing are submitted as unnamed lambdas, so confirm before relying
// on or removing these declarations.
87class Bitonic_sort_B32_morton32;
88class Bitonic_sort_B16_morton32;
89class Bitonic_sort_B8_morton32;
90class Bitonic_sort_B4_morton32;
91class Bitonic_sort_B2_morton32;
92
93class Bitonic_sort_B32_morton64;
94class Bitonic_sort_B16_morton64;
95class Bitonic_sort_B8_morton64;
96class Bitonic_sort_B4_morton64;
97class Bitonic_sort_B2_morton64;
98
100
// Sort `buf_key` in place with a multi-pass bitonic sort, applying every
// exchange to `buf_values` as well so that each value follows its key.
//
// q          : queue every pass is submitted to (one kernel per pass).
// buf_key    : keys, accessed read_write by each kernel.
// buf_values : values, accessed read_write by each kernel.
// len        : number of elements to sort; must be a power of two, otherwise
//              the error branch at the top of the body is taken.
//
// Pass scheduling: for each `length` (1, 2, 4, ... < len) the stride `inc`
// starts at `length` and is consumed from the top. Each iteration picks the
// widest enabled kernel whose threshold `inc` meets (B16 for inc >= 8, B8 for
// inc >= 4, B4 for inc >= 2, otherwise B2); that kernel performs `ninc`
// compare-exchange stages at once and `inc` is then shifted right by `ninc`.
// The B32 kernel is compiled out while MAXORDER_SORT_KERNEL is 16.
//
// Direction: each thread sets `reverse = ((_dir & i) == 0)` with
// `_dir = length << 1`. In the last outer iteration `_dir == len`, which is
// larger than every index, so `reverse` is true everywhere and the
// compare-exchange leaves the smaller key first.
//
// NOTE(review): listing lines 106 and 110 are absent from this extract, so the
// statements that consume the two string literals below (the error raised for
// a non power-of-two `len`, and the debug log call) are not visible here.
101 template<class Tkey, class Tval>
102 void sort_by_key_bitonic_legacy(
103 sycl::queue &q, sycl::buffer<Tkey> &buf_key, sycl::buffer<Tval> &buf_values, u32 len) {
104
105 if (!shambase::is_pow_of_two(len)) {
107 "this algorithm can only be used with length that are powers of two");
108 }
109
111 "BitonicSorter", "submit : sycl_sort_morton_key_pair<u32, MultiKernel>");
112
// `length` is the size of the already-sorted runs being merged in this round.
113 for (u32 length = 1; length < len; length <<= 1) {
114 u32 inc = length;
115 while (inc > 0) {
116 // log("inc : %d\n",inc);
117 // int ninc = 1;
// ninc == 0 means "no kernel selected yet" for this iteration.
118 u32 ninc = 0;
119
120// B32 sort kernel is less performant than the B16 because of cache size
121#if MAXORDER_SORT_KERNEL >= 32
122 if (inc >= 16 && ninc == 0) {
123 ninc = 5;
124 unsigned int nThreads = len >> ninc;
125 sycl::range<1> range{nThreads};
126
127 auto ker_sort_morton_b32 = [&](sycl::handler &cgh) {
128 sycl::accessor m{buf_key, cgh, sycl::read_write};
129 sycl::accessor id{buf_values, cgh, sycl::read_write};
130
131 cgh.parallel_for(range, [=](sycl::item<1> item) {
132 //(__global data_t * data,__global uint * ids,int inc,int dir)
133
134 u32 _inc = inc;
135 u32 _dir = length << 1;
136
137 _inc >>= 4;
// get_id(0), as in the other kernels of this function.
138 int t = item.get_id(0); // thread index
139 int low = t & (_inc - 1); // low order bits (below INC)
140 int i = ((t - low) << 5) + low; // insert 00000 at position INC
141 bool reverse = ((_dir & i) == 0); // asc/desc order
142
143 // Load
144 Tkey x[32];
145 for (int k = 0; k < 32; k++)
146 x[k] = m[k * _inc + i];
147
// Value scratch array uses Tval so it matches the element type of `id`.
148 Tval idx[32];
149 for (int k = 0; k < 32; k++)
150 idx[k] = id[k * _inc + i];
151
152 // Sort
153 B32V(x, idx, 0)
154
155 // Store
156 for (int k = 0; k < 32; k++)
157 m[k * _inc + i] = x[k];
158 for (int k = 0; k < 32; k++)
159 id[k * _inc + i] = idx[k];
160 });
161 };
162 q.submit(ker_sort_morton_b32);
163 }
164#endif
165
166#if MAXORDER_SORT_KERNEL >= 16
167 if (inc >= 8 && ninc == 0) {
168 ninc = 4;
169 unsigned int nThreads = len >> ninc;
170 sycl::range<1> range{nThreads};
171
172 auto ker_sort_morton_b16 = [&](sycl::handler &cgh) {
173 sycl::accessor m{buf_key, cgh, sycl::read_write};
174 sycl::accessor id{buf_values, cgh, sycl::read_write};
175
176 cgh.parallel_for(range, [=](sycl::item<1> item) {
177 //(__global data_t * data,__global uint * ids,int inc,int dir)
178
179 u32 _inc = inc;
180 u32 _dir = length << 1;
181
182 _inc >>= 3;
183 int t = item.get_id(0); // thread index
184 int low = t & (_inc - 1); // low order bits (below INC)
185 int i = ((t - low) << 4) + low; // insert 0000 at position INC
186 bool reverse = ((_dir & i) == 0); // asc/desc order
187
188 // Load
189 Tkey x[16];
190 for (int k = 0; k < 16; k++)
191 x[k] = m[k * _inc + i];
192
193 Tval idx[16];
194 for (int k = 0; k < 16; k++)
195 idx[k] = id[k * _inc + i];
196
197 // Sort
198 B16V(x, idx, 0)
199
200 // Store
201 for (int k = 0; k < 16; k++)
202 m[k * _inc + i] = x[k];
203 for (int k = 0; k < 16; k++)
204 id[k * _inc + i] = idx[k];
205 });
206 };
207 q.submit(ker_sort_morton_b16);
208
209 // sort_kernel_B8(arg_eq,* buf_key->buf,*
210 // particles::buf_ids->buf,inc,length<<1);//.wait();
211 }
212#endif
213
214#if MAXORDER_SORT_KERNEL >= 8
215 // B8
216 if (inc >= 4 && ninc == 0) {
217 ninc = 3;
218 unsigned int nThreads = len >> ninc;
219 sycl::range<1> range{nThreads};
220
221 auto ker_sort_morton_b8 = [&](sycl::handler &cgh) {
222 sycl::accessor m{buf_key, cgh, sycl::read_write};
223 sycl::accessor id{buf_values, cgh, sycl::read_write};
224
225 cgh.parallel_for(range, [=](sycl::item<1> item) {
226 //(__global data_t * data,__global uint * ids,int inc,int dir)
227
228 u32 _inc = inc;
229 u32 _dir = length << 1;
230
231 _inc >>= 2;
232 int t = item.get_id(0); // thread index
233 int low = t & (_inc - 1); // low order bits (below INC)
234 int i = ((t - low) << 3) + low; // insert 000 at position INC
235 bool reverse = ((_dir & i) == 0); // asc/desc order
236
237 // Load
238 Tkey x[8];
239 for (int k = 0; k < 8; k++)
240 x[k] = m[k * _inc + i];
241
242 Tval idx[8];
243 for (int k = 0; k < 8; k++)
244 idx[k] = id[k * _inc + i];
245
246 // Sort
247 B8V(x, idx, 0)
248
249 // Store
250 for (int k = 0; k < 8; k++)
251 m[k * _inc + i] = x[k];
252 for (int k = 0; k < 8; k++)
253 id[k * _inc + i] = idx[k];
254 });
255 };
256 q.submit(ker_sort_morton_b8);
257
258 // sort_kernel_B8(arg_eq,* buf_key->buf,*
259 // particles::buf_ids->buf,inc,length<<1);//.wait();
260 }
261#endif
262
263#if MAXORDER_SORT_KERNEL >= 4
264 // B4
265 if (inc >= 2 && ninc == 0) {
266 ninc = 2;
267 unsigned int nThreads = len >> ninc;
268 sycl::range<1> range{nThreads};
269 // sort_kernel_B4(arg_eq,* buf_key->buf,*
270 // particles::buf_ids->buf,inc,length<<1);
271 auto ker_sort_morton_b4 = [&](sycl::handler &cgh) {
272 sycl::accessor m{buf_key, cgh, sycl::read_write};
273 sycl::accessor id{buf_values, cgh, sycl::read_write};
274 cgh.parallel_for(range, [=](sycl::item<1> item) {
275 //(__global data_t * data,__global uint * ids,int inc,int dir)
276
277 u32 _inc = inc;
278 u32 _dir = length << 1;
279
280 _inc >>= 1;
281 int t = item.get_id(0); // thread index
282 int low = t & (_inc - 1); // low order bits (below INC)
283 int i = ((t - low) << 2) + low; // insert 00 at position INC
284 bool reverse = ((_dir & i) == 0); // asc/desc order
285
286 // Load
287 Tkey x0 = m[0 + i];
288 Tkey x1 = m[_inc + i];
289 Tkey x2 = m[2 * _inc + i];
290 Tkey x3 = m[3 * _inc + i];
291
292 Tval idx0 = id[0 + i];
293 Tval idx1 = id[_inc + i];
294 Tval idx2 = id[2 * _inc + i];
295 Tval idx3 = id[3 * _inc + i];
296
297 // Sort
298 ORDER(x0, x2, idx0, idx2)
299 ORDER(x1, x3, idx1, idx3)
300 ORDER(x0, x1, idx0, idx1)
301 ORDER(x2, x3, idx2, idx3)
302
303 // Store
304 m[0 + i] = x0;
305 m[_inc + i] = x1;
306 m[2 * _inc + i] = x2;
307 m[3 * _inc + i] = x3;
308
309 id[0 + i] = idx0;
310 id[_inc + i] = idx1;
311 id[2 * _inc + i] = idx2;
312 id[3 * _inc + i] = idx3;
313 });
314 };
315 q.submit(ker_sort_morton_b4);
316 }
317#endif
318
319 // B2
// Fallback: taken whenever none of the wider kernels above was selected.
320 if (ninc == 0) {
321 ninc = 1;
322 unsigned int nThreads = len >> ninc;
323 sycl::range<1> range{nThreads};
324 // sort_kernel_B2(arg_eq,* buf_key->buf,*
325 // particles::buf_ids->buf,inc,length<<1);
326 auto ker_sort_morton_b2 = [&](sycl::handler &cgh) {
327 sycl::accessor m{buf_key, cgh, sycl::read_write};
328 sycl::accessor id{buf_values, cgh, sycl::read_write};
329
330 cgh.parallel_for(range, [=](sycl::item<1> item) {
331 //(__global data_t * data,__global uint * ids,int inc,int dir)
332
333 u32 _inc = inc;
334 u32 _dir = length << 1;
335
336 int t = item.get_id(0); // thread index
337 int low = t & (_inc - 1); // low order bits (below INC)
338 int i = (t << 1) - low; // insert 0 at position INC
339 bool reverse = ((_dir & i) == 0); // asc/desc order
340
341 // Load
342 Tkey x0 = m[0 + i];
343 Tkey x1 = m[_inc + i];
344 Tval idx0 = id[0 + i];
345 Tval idx1 = id[_inc + i];
346
347 // Sort
348 ORDER(x0, x1, idx0, idx1)
349
350 // Store
351 m[0 + i] = x0;
352 m[_inc + i] = x1;
353 id[0 + i] = idx0;
354 id[_inc + i] = idx1;
355 });
356 };
357 q.submit(ker_sort_morton_b2);
358 }
359
// Skip the `ninc` stages the selected kernel just performed.
360 inc >>= ninc;
361 }
362 }
363 }
364
// Explicit instantiation: u32 keys carrying u32 values.
365 template void sort_by_key_bitonic_legacy(
366 sycl::queue &q, sycl::buffer<u32> &buf_key, sycl::buffer<u32> &buf_values, u32 len);
367
// Explicit instantiation: u64 keys carrying u32 values.
368 template void sort_by_key_bitonic_legacy(
369 sycl::queue &q, sycl::buffer<u64> &buf_key, sycl::buffer<u32> &buf_values, u32 len);
370
371} // namespace shamalgs::algorithm::details
constexpr const char * uint
Specific internal energy u.
std::uint32_t u32
32 bit unsigned integer
This header file contains utility functions related to exception handling in the code.
namespace to store algorithms implemented by shamalgs
constexpr bool is_pow_of_two(T v) noexcept
Determine whether v is a power of two, also checking the v == 0 case. Source: https://graphics.stanford....
Definition integer.hpp:49
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
void debug_sycl_ln(std::string module_name, Types... var2)
Prints a log message with multiple arguments followed by a newline.
Definition logs.hpp:133