Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
bitonicSort_updated.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
#include "shambase/integer.hpp"
#include <stdexcept>
20
21// modified from http://www.bealto.com/gpu-sorting.html
22
24
25 template<class Tkey, class Tval>
27
28 inline static void _order(Tkey &a, Tkey &b, Tval &va, Tval &vb, bool reverse) {
29 bool swap = reverse ^ (a < b);
30 Tkey auxa = a;
31 Tkey auxb = b;
32 Tval auxida = va;
33 Tval auxidb = vb;
34 a = (swap) ? auxb : auxa;
35 b = (swap) ? auxa : auxb;
36 va = (swap) ? auxidb : auxida;
37 vb = (swap) ? auxida : auxidb;
38 }
39
40 inline static void _orderV(
41 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, u32 b, bool reverse) {
42 bool swap = reverse ^ (x[a] < x[b]);
43 auto auxa = x[a];
44 auto auxb = x[b];
45 auto auxida = vx[a];
46 auto auxidb = vx[b];
47 x[a] = (swap) ? auxb : auxa;
48 x[b] = (swap) ? auxa : auxb;
49 vx[a] = (swap) ? auxidb : auxida;
50 vx[b] = (swap) ? auxida : auxidb;
51 }
52
53 template<u32 stencil_size>
54 static void order_stencil(Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse);
55
56 template<>
57 inline void order_stencil<2>(
58 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse) {
59 _orderV(x, vx, a, a + 1, reverse);
60 }
61
62 template<>
63 inline void order_stencil<4>(
64 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse) {
65#pragma unroll
66 for (int i4 = 0; i4 < 2; i4++) {
67 _orderV(x, vx, a + i4, a + i4 + 2, reverse);
68 }
69 order_stencil<2>(x, vx, a, reverse);
70 order_stencil<2>(x, vx, a + 2, reverse);
71 }
72
73 template<>
74 inline void order_stencil<8>(
75 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse) {
76#pragma unroll
77 for (int i8 = 0; i8 < 4; i8++) {
78 _orderV(x, vx, a + i8, a + i8 + 4, reverse);
79 }
80 order_stencil<4>(x, vx, a, reverse);
81 order_stencil<4>(x, vx, a + 4, reverse);
82 }
83
84 template<>
85 inline void order_stencil<16>(
86 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse) {
87#pragma unroll
88 for (int i16 = 0; i16 < 8; i16++) {
89 _orderV(x, vx, a + i16, a + i16 + 8, reverse);
90 }
91 order_stencil<8>(x, vx, a, reverse);
92 order_stencil<8>(x, vx, a + 8, reverse);
93 }
94
95 template<>
96 inline void order_stencil<32>(
97 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse) {
98#pragma unroll
99 for (int i32 = 0; i32 < 16; i32++) {
100 _orderV(x, vx, a + i32, a + i32 + 16, reverse);
101 }
102 order_stencil<16>(x, vx, a, reverse);
103 order_stencil<16>(x, vx, a + 16, reverse);
104 }
105
106 template<u32 stencil_size>
107 static void order_kernel(
108 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t);
109
110 template<>
111 inline void order_kernel<32>(
112 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t) {
113 u32 _inc = inc;
114 u32 _dir = length << 1U;
115
116 _inc >>= 4;
117 int low = t & (_inc - 1); // low order bits (below INC)
118 int i = ((t - low) << 5) + low; // insert 000 at position INC
119 bool reverse = ((_dir & i) == 0); // asc/desc order
120
121 // Load
122 Tkey x[32];
123#pragma unroll
124 for (int k = 0; k < 32; k++)
125 x[k] = m[k * _inc + i];
126
127 uint idx[32];
128#pragma unroll
129 for (int k = 0; k < 32; k++)
130 idx[k] = id[k * _inc + i];
131
132 // Sort
133 order_stencil<32>(x, idx, 0, reverse);
134
135// Store
136#pragma unroll
137 for (int k = 0; k < 32; k++)
138 m[k * _inc + i] = x[k];
139#pragma unroll
140 for (int k = 0; k < 32; k++)
141 id[k * _inc + i] = idx[k];
142 }
143
144 template<>
145 inline void order_kernel<16>(
146 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t) {
147
148 u32 _inc = inc;
149 u32 _dir = length << 1;
150
151 _inc >>= 3;
152 int low = t & (_inc - 1); // low order bits (below INC)
153 int i = ((t - low) << 4) + low; // insert 000 at position INC
154 bool reverse = ((_dir & i) == 0); // asc/desc order
155
156 // Load
157 Tkey x[16];
158#pragma unroll
159 for (int k = 0; k < 16; k++)
160 x[k] = m[k * _inc + i];
161
162 Tval idx[16];
163#pragma unroll
164 for (int k = 0; k < 16; k++)
165 idx[k] = id[k * _inc + i];
166
167 // Sort
168 order_stencil<16>(x, idx, 0, reverse);
169
170// Store
171#pragma unroll
172 for (int k = 0; k < 16; k++)
173 m[k * _inc + i] = x[k];
174#pragma unroll
175 for (int k = 0; k < 16; k++)
176 id[k * _inc + i] = idx[k];
177 }
178
179 template<>
180 inline void order_kernel<8>(
181 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t) {
182 u32 _inc = inc;
183 u32 _dir = length << 1;
184
185 _inc >>= 2;
186 int low = t & (_inc - 1); // low order bits (below INC)
187 int i = ((t - low) << 3) + low; // insert 000 at position INC
188 bool reverse = ((_dir & i) == 0); // asc/desc order
189
190 // Load
191 Tkey x[8];
192#pragma unroll
193 for (int k = 0; k < 8; k++)
194 x[k] = m[k * _inc + i];
195
196 Tval idx[8];
197#pragma unroll
198 for (int k = 0; k < 8; k++)
199 idx[k] = id[k * _inc + i];
200
201 // Sort
202 order_stencil<8>(x, idx, 0, reverse);
203
204// Store
205#pragma unroll
206 for (int k = 0; k < 8; k++)
207 m[k * _inc + i] = x[k];
208#pragma unroll
209 for (int k = 0; k < 8; k++)
210 id[k * _inc + i] = idx[k];
211 }
212
213 template<>
214 inline void order_kernel<4>(
215 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t) {
216 u32 _inc = inc;
217 u32 _dir = length << 1;
218
219 _inc >>= 1;
220 int low = t & (_inc - 1); // low order bits (below INC)
221 int i = ((t - low) << 2) + low; // insert 00 at position INC
222 bool reverse = ((_dir & i) == 0); // asc/desc order
223
224 // Load
225 Tkey x0 = m[0 + i];
226 Tkey x1 = m[_inc + i];
227 Tkey x2 = m[2 * _inc + i];
228 Tkey x3 = m[3 * _inc + i];
229
230 Tval idx0 = id[0 + i];
231 Tval idx1 = id[_inc + i];
232 Tval idx2 = id[2 * _inc + i];
233 Tval idx3 = id[3 * _inc + i];
234
235 // Sort
236 _order(x0, x2, idx0, idx2, reverse);
237 _order(x1, x3, idx1, idx3, reverse);
238 _order(x0, x1, idx0, idx1, reverse);
239 _order(x2, x3, idx2, idx3, reverse);
240
241 // Store
242 m[0 + i] = x0;
243 m[_inc + i] = x1;
244 m[2 * _inc + i] = x2;
245 m[3 * _inc + i] = x3;
246
247 id[0 + i] = idx0;
248 id[_inc + i] = idx1;
249 id[2 * _inc + i] = idx2;
250 id[3 * _inc + i] = idx3;
251 }
252
253 template<>
254 inline void order_kernel<2>(
255 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t) {
256 u32 _inc = inc;
257 u32 _dir = length << 1;
258
259 int low = t & (_inc - 1); // low order bits (below INC)
260 int i = (t << 1) - low; // insert 0 at position INC
261 bool reverse = ((_dir & i) == 0); // asc/desc order
262
263 u32 addr_1 = 0 + i;
264 u32 addr_2 = _inc + i;
265
266 // Load
267 Tkey x0 = m[addr_1];
268 Tkey x1 = m[addr_2];
269 Tval idx0 = id[addr_1];
270 Tval idx1 = id[addr_2];
271
272 // Sort
273 _order(x0, x1, idx0, idx1, reverse);
274
275 // Store
276 m[addr_1] = x0;
277 m[addr_2] = x1;
278 id[addr_1] = idx0;
279 id[addr_2] = idx1;
280 }
281 };
282
283 template<class Tkey, class Tval, u32 MaxStencilSize>
284 void sort_by_key_bitonic_updated(
285 sycl::queue &q, sycl::buffer<Tkey> &buf_key, sycl::buffer<Tval> &buf_values, u32 len) {
286
287 if (!shambase::is_pow_of_two(len)) {
289 "this algorithm can only be used with length that are powers of two");
290 }
291
292 using B = OrderingPrimitive<Tkey, Tval>;
293
294 for (u32 length = 1; length < len; length <<= 1) {
295 u32 inc = length;
296 while (inc > 0) {
297 // log("inc : %d\n",inc);
298 // int ninc = 1;
299 u32 ninc = 0;
300
301 // B32 sort kernel is less performant than the B16 because of cache size
302 if constexpr (MaxStencilSize >= 32) {
303 if (inc >= 16 && ninc == 0) {
304 ninc = 5;
305 unsigned int nThreads = len >> ninc;
306 sycl::range<1> range{nThreads};
307
308 q.submit([&](sycl::handler &cgh) {
309 sycl::accessor accm{buf_key, cgh, sycl::read_write};
310 sycl::accessor accid{buf_values, cgh, sycl::read_write};
311
312 shambase::parallel_for(
313 cgh, nThreads, "bitonic sort pass B32", [=](u64 gid) {
314 //(__global data_t * data,__global uint * ids,int inc,int dir)
315
316 Tkey *m = &(accm[0]);
317 Tval *id = &(accid[0]);
318 B::template order_kernel<32>(m, id, inc, length, gid);
319 });
320 });
321 }
322 }
323
324 if constexpr (MaxStencilSize >= 16) {
325 if (inc >= 8 && ninc == 0) {
326 ninc = 4;
327 unsigned int nThreads = len >> ninc;
328 sycl::range<1> range{nThreads};
329
330 q.submit([&](sycl::handler &cgh) {
331 sycl::accessor accm{buf_key, cgh, sycl::read_write};
332 sycl::accessor accid{buf_values, cgh, sycl::read_write};
333
334 shambase::parallel_for(
335 cgh, nThreads, "bitonic sort pass B16", [=](u64 gid) {
336 //(__global data_t * data,__global uint * ids,int inc,int dir)
337
338 Tkey *m = &(accm[0]);
339 Tval *id = &(accid[0]);
340 B::template order_kernel<16>(m, id, inc, length, gid);
341 });
342 });
343
344 // sort_kernel_B8(arg_eq,* buf_key->buf,*
345 // particles::buf_ids->buf,inc,length<<1);//.wait();
346 }
347 }
348
349 if constexpr (MaxStencilSize >= 8) {
350 // B8
351 if (inc >= 4 && ninc == 0) {
352 ninc = 3;
353 unsigned int nThreads = len >> ninc;
354 sycl::range<1> range{nThreads};
355
356 q.submit([&](sycl::handler &cgh) {
357 sycl::accessor accm{buf_key, cgh, sycl::read_write};
358 sycl::accessor accid{buf_values, cgh, sycl::read_write};
359
360 shambase::parallel_for(
361 cgh, nThreads, "bitonic sort pass B8", [=](u64 gid) {
362 //(__global data_t * data,__global uint * ids,int inc,int dir)
363
364 Tkey *m = &(accm[0]);
365 Tval *id = &(accid[0]);
366 B::template order_kernel<8>(m, id, inc, length, gid);
367 });
368 });
369
370 // sort_kernel_B8(arg_eq,* buf_key->buf,*
371 // particles::buf_ids->buf,inc,length<<1);//.wait();
372 }
373 }
374
375 if constexpr (MaxStencilSize >= 4) {
376 // B4
377 if (inc >= 2 && ninc == 0) {
378 ninc = 2;
379 unsigned int nThreads = len >> ninc;
380 sycl::range<1> range{nThreads};
381 // sort_kernel_B4(arg_eq,* buf_key->buf,*
382 // particles::buf_ids->buf,inc,length<<1);
383
384 q.submit([&](sycl::handler &cgh) {
385 sycl::accessor accm{buf_key, cgh, sycl::read_write};
386 sycl::accessor accid{buf_values, cgh, sycl::read_write};
387
388 shambase::parallel_for(
389 cgh, nThreads, "bitonic sort pass B4", [=](u64 gid) {
390 Tkey *m = &(accm[0]);
391 Tval *id = &(accid[0]);
392 B::template order_kernel<4>(m, id, inc, length, gid);
393 });
394 });
395 }
396 }
397
398 // B2
399 if (ninc == 0) {
400 ninc = 1;
401 unsigned int nThreads = len >> ninc;
402 sycl::range<1> range{nThreads};
403 // sort_kernel_B2(arg_eq,* buf_key->buf,*
404 // particles::buf_ids->buf,inc,length<<1);
405
406 q.submit([&](sycl::handler &cgh) {
407 sycl::accessor accm{buf_key, cgh, sycl::read_write};
408 sycl::accessor accid{buf_values, cgh, sycl::read_write};
409
410 shambase::parallel_for(cgh, nThreads, "bitonic sort pass B2", [=](u64 gid) {
411 //(__global data_t * data,__global uint * ids,int inc,int dir)
412
413 Tkey *m = &(accm[0]);
414 Tval *id = &(accid[0]);
415
416 B::template order_kernel<2>(m, id, inc, length, gid);
417 });
418 });
419 }
420
421 inc >>= ninc;
422 }
423 }
424 }
425
426 template void sort_by_key_bitonic_updated<u32, u32, 16>(
427 sycl::queue &q, sycl::buffer<u32> &buf_key, sycl::buffer<u32> &buf_values, u32 len);
428
429 template void sort_by_key_bitonic_updated<u64, u32, 16>(
430 sycl::queue &q, sycl::buffer<u64> &buf_key, sycl::buffer<u32> &buf_values, u32 len);
431
432 template void sort_by_key_bitonic_updated<u32, u32, 8>(
433 sycl::queue &q, sycl::buffer<u32> &buf_key, sycl::buffer<u32> &buf_values, u32 len);
434
435 template void sort_by_key_bitonic_updated<u64, u32, 8>(
436 sycl::queue &q, sycl::buffer<u64> &buf_key, sycl::buffer<u32> &buf_values, u32 len);
437
438 template void sort_by_key_bitonic_updated<u32, u32, 32>(
439 sycl::queue &q, sycl::buffer<u32> &buf_key, sycl::buffer<u32> &buf_values, u32 len);
440
441 template void sort_by_key_bitonic_updated<u64, u32, 32>(
442 sycl::queue &q, sycl::buffer<u64> &buf_key, sycl::buffer<u32> &buf_values, u32 len);
443
444} // namespace shamalgs::algorithm::details
std::int8_t i8
8 bit integer
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int16_t i16
16 bit integer
std::int32_t i32
32 bit integer
namespace to store algorithms implemented by shamalgs
constexpr bool is_pow_of_two(T v) noexcept
determine if v is a power of two and check if v==0 Source : https://graphics.stanford....
Definition integer.hpp:49
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.