Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
bitonicSort_updated_usm.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
17#include "shambase/integer.hpp"
21
22// modified from http://www.bealto.com/gpu-sorting.html
23
25
26 template<class Tkey, class Tval>
27 struct OrderingPrimitive {
28
29 inline static void _order(Tkey &a, Tkey &b, Tval &va, Tval &vb, bool reverse) {
30 bool swap = reverse ^ (a < b);
31 Tkey auxa = a;
32 Tkey auxb = b;
33 Tval auxida = va;
34 Tval auxidb = vb;
35 a = (swap) ? auxb : auxa;
36 b = (swap) ? auxa : auxb;
37 va = (swap) ? auxidb : auxida;
38 vb = (swap) ? auxida : auxidb;
39 }
40
41 inline static void _orderV(
42 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, u32 b, bool reverse) {
43 bool swap = reverse ^ (x[a] < x[b]);
44 auto auxa = x[a];
45 auto auxb = x[b];
46 auto auxida = vx[a];
47 auto auxidb = vx[b];
48 x[a] = (swap) ? auxb : auxa;
49 x[b] = (swap) ? auxa : auxb;
50 vx[a] = (swap) ? auxidb : auxida;
51 vx[b] = (swap) ? auxida : auxidb;
52 }
53
54 template<u32 stencil_size>
55 static void order_stencil(Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse);
56
57 template<>
58 inline void order_stencil<2>(
59 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse) {
60 _orderV(x, vx, a, a + 1, reverse);
61 }
62
63 template<>
64 inline void order_stencil<4>(
65 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse) {
66#pragma unroll
67 for (int i4 = 0; i4 < 2; i4++) {
68 _orderV(x, vx, a + i4, a + i4 + 2, reverse);
69 }
70 order_stencil<2>(x, vx, a, reverse);
71 order_stencil<2>(x, vx, a + 2, reverse);
72 }
73
74 template<>
75 inline void order_stencil<8>(
76 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse) {
77#pragma unroll
78 for (int i8 = 0; i8 < 4; i8++) {
79 _orderV(x, vx, a + i8, a + i8 + 4, reverse);
80 }
81 order_stencil<4>(x, vx, a, reverse);
82 order_stencil<4>(x, vx, a + 4, reverse);
83 }
84
85 template<>
86 inline void order_stencil<16>(
87 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse) {
88#pragma unroll
89 for (int i16 = 0; i16 < 8; i16++) {
90 _orderV(x, vx, a + i16, a + i16 + 8, reverse);
91 }
92 order_stencil<8>(x, vx, a, reverse);
93 order_stencil<8>(x, vx, a + 8, reverse);
94 }
95
96 template<>
97 inline void order_stencil<32>(
98 Tkey *__restrict__ x, Tval *__restrict__ vx, u32 a, bool reverse) {
99#pragma unroll
100 for (int i32 = 0; i32 < 16; i32++) {
101 _orderV(x, vx, a + i32, a + i32 + 16, reverse);
102 }
103 order_stencil<16>(x, vx, a, reverse);
104 order_stencil<16>(x, vx, a + 16, reverse);
105 }
106
107 template<u32 stencil_size>
108 static void order_kernel(
109 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t);
110
111 template<>
112 inline void order_kernel<32>(
113 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t) {
114 u32 _inc = inc;
115 u32 _dir = length << 1U;
116
117 _inc >>= 4;
118 int low = t & (_inc - 1); // low order bits (below INC)
119 int i = ((t - low) << 5) + low; // insert 000 at position INC
120 bool reverse = ((_dir & i) == 0); // asc/desc order
121
122 // Load
123 Tkey x[32];
124#pragma unroll
125 for (int k = 0; k < 32; k++)
126 x[k] = m[k * _inc + i];
127
128 Tval idx[32];
129#pragma unroll
130 for (int k = 0; k < 32; k++)
131 idx[k] = id[k * _inc + i];
132
133 // Sort
134 order_stencil<32>(x, idx, 0, reverse);
135
136// Store
137#pragma unroll
138 for (int k = 0; k < 32; k++)
139 m[k * _inc + i] = x[k];
140#pragma unroll
141 for (int k = 0; k < 32; k++)
142 id[k * _inc + i] = idx[k];
143 }
144
145 template<>
146 inline void order_kernel<16>(
147 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t) {
148
149 u32 _inc = inc;
150 u32 _dir = length << 1;
151
152 _inc >>= 3;
153 int low = t & (_inc - 1); // low order bits (below INC)
154 int i = ((t - low) << 4) + low; // insert 000 at position INC
155 bool reverse = ((_dir & i) == 0); // asc/desc order
156
157 // Load
158 Tkey x[16];
159#pragma unroll
160 for (int k = 0; k < 16; k++)
161 x[k] = m[k * _inc + i];
162
163 Tval idx[16];
164#pragma unroll
165 for (int k = 0; k < 16; k++)
166 idx[k] = id[k * _inc + i];
167
168 // Sort
169 order_stencil<16>(x, idx, 0, reverse);
170
171// Store
172#pragma unroll
173 for (int k = 0; k < 16; k++)
174 m[k * _inc + i] = x[k];
175#pragma unroll
176 for (int k = 0; k < 16; k++)
177 id[k * _inc + i] = idx[k];
178 }
179
180 template<>
181 inline void order_kernel<8>(
182 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t) {
183 u32 _inc = inc;
184 u32 _dir = length << 1;
185
186 _inc >>= 2;
187 int low = t & (_inc - 1); // low order bits (below INC)
188 int i = ((t - low) << 3) + low; // insert 000 at position INC
189 bool reverse = ((_dir & i) == 0); // asc/desc order
190
191 // Load
192 Tkey x[8];
193#pragma unroll
194 for (int k = 0; k < 8; k++)
195 x[k] = m[k * _inc + i];
196
197 Tval idx[8];
198#pragma unroll
199 for (int k = 0; k < 8; k++)
200 idx[k] = id[k * _inc + i];
201
202 // Sort
203 order_stencil<8>(x, idx, 0, reverse);
204
205// Store
206#pragma unroll
207 for (int k = 0; k < 8; k++)
208 m[k * _inc + i] = x[k];
209#pragma unroll
210 for (int k = 0; k < 8; k++)
211 id[k * _inc + i] = idx[k];
212 }
213
214 template<>
215 inline void order_kernel<4>(
216 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t) {
217 u32 _inc = inc;
218 u32 _dir = length << 1;
219
220 _inc >>= 1;
221 int low = t & (_inc - 1); // low order bits (below INC)
222 int i = ((t - low) << 2) + low; // insert 00 at position INC
223 bool reverse = ((_dir & i) == 0); // asc/desc order
224
225 // Load
226 Tkey x0 = m[0 + i];
227 Tkey x1 = m[_inc + i];
228 Tkey x2 = m[2 * _inc + i];
229 Tkey x3 = m[3 * _inc + i];
230
231 Tval idx0 = id[0 + i];
232 Tval idx1 = id[_inc + i];
233 Tval idx2 = id[2 * _inc + i];
234 Tval idx3 = id[3 * _inc + i];
235
236 // Sort
237 _order(x0, x2, idx0, idx2, reverse);
238 _order(x1, x3, idx1, idx3, reverse);
239 _order(x0, x1, idx0, idx1, reverse);
240 _order(x2, x3, idx2, idx3, reverse);
241
242 // Store
243 m[0 + i] = x0;
244 m[_inc + i] = x1;
245 m[2 * _inc + i] = x2;
246 m[3 * _inc + i] = x3;
247
248 id[0 + i] = idx0;
249 id[_inc + i] = idx1;
250 id[2 * _inc + i] = idx2;
251 id[3 * _inc + i] = idx3;
252 }
253
254 template<>
255 inline void order_kernel<2>(
256 Tkey *__restrict__ m, Tval *__restrict__ id, u32 inc, u32 length, i32 t) {
257 u32 _inc = inc;
258 u32 _dir = length << 1;
259
260 int low = t & (_inc - 1); // low order bits (below INC)
261 int i = (t << 1) - low; // insert 0 at position INC
262 bool reverse = ((_dir & i) == 0); // asc/desc order
263
264 u32 addr_1 = 0 + i;
265 u32 addr_2 = _inc + i;
266
267 // Load
268 Tkey x0 = m[addr_1];
269 Tkey x1 = m[addr_2];
270 Tval idx0 = id[addr_1];
271 Tval idx1 = id[addr_2];
272
273 // Sort
274 _order(x0, x1, idx0, idx1, reverse);
275
276 // Store
277 m[addr_1] = x0;
278 m[addr_2] = x1;
279 id[addr_1] = idx0;
280 id[addr_2] = idx1;
281 }
282 };
283
// Driver of the bitonic sort-by-key: repeatedly launches OrderingPrimitive::order_kernel
// over `len` (key, value) pairs until both loops below are exhausted. The kernels read and
// write the buffers passed through sham::MultiRef{buf_key, buf_values}.
//
// Template parameters:
//   Tkey, Tval      - element types of the key and value buffers
//   MaxStencilSize  - largest order_kernel<N> this instantiation may use (compared against
//                     32, 16, 8 and 4 below; order_kernel<2> is always available)
// Parameters visible in this listing:
//   sched       - scheduler whose queue receives the kernel launches
//   buf_values  - values moved together with their keys
//   len         - number of elements; tested with shambase::is_pow_of_two
//
// NOTE(review): this listing skips original lines 287, 292, 311/313, 327/329, 347/349,
// 367/369 and 383/385. `buf_key` is used below but its parameter declaration is not
// visible, and neither is the call that consumes the error string nor the head of each
// kernel launch - confirm against the real file before relying on these comments.
284 template<class Tkey, class Tval, u32 MaxStencilSize>
285 void sort_by_key_bitonic_updated_usm(
286 const sham::DeviceScheduler_ptr &sched,
288 sham::DeviceBuffer<Tval> &buf_values,
289 u32 len) {
290
// Guard for lengths that are not powers of two. Presumably the elided line 292 throws with
// the message below - TODO confirm.
291 if (!shambase::is_pow_of_two(len)) {
293 "this algorithm can only be used with length that are powers of two");
294 }
295
296 using B = OrderingPrimitive<Tkey, Tval>;
297
// Outer loop: `length` doubles from 1 while it stays below `len`.
298 for (u32 length = 1; length < len; length <<= 1) {
// Inner loop: `inc` starts at `length` and is shifted right by `ninc` after each launch
// until it reaches 0.
299 u32 inc = length;
300 while (inc > 0) {
301 // log("inc : %d\n",inc);
302 // int ninc = 1;
// ninc == 0 means "no kernel chosen yet for this iteration"; the first branch that matches
// sets it, which disables all the following ones.
303 u32 ninc = 0;
304
305 // B32 sort kernel is less performant than the B16 because of cache size
306 if constexpr (MaxStencilSize >= 32) {
// 32-wide stencil: needs inc >= 16, consumes 5 bits of inc, uses len / 32 work-items.
307 if (inc >= 16 && ninc == 0) {
308 ninc = 5;
309 unsigned int nThreads = len >> ninc;
310
312 sched->get_queue(),
314 sham::MultiRef{buf_key, buf_values},
315 nThreads,
// gid is a u64 here but order_kernel takes its work-item index as i32.
316 [=](u64 gid, Tkey *m, Tval *id) {
317 B::template order_kernel<32>(m, id, inc, length, gid);
318 });
319 }
320 }
321
322 if constexpr (MaxStencilSize >= 16) {
// 16-wide stencil: needs inc >= 8, consumes 4 bits of inc, uses len / 16 work-items.
323 if (inc >= 8 && ninc == 0) {
324 ninc = 4;
325 unsigned int nThreads = len >> ninc;
326
328 sched->get_queue(),
330 sham::MultiRef{buf_key, buf_values},
331 nThreads,
332 [=](u64 gid, Tkey *m, Tval *id) {
333 B::template order_kernel<16>(m, id, inc, length, gid);
334 });
335
336 // sort_kernel_B8(arg_eq,* buf_key->buf,*
337 // particles::buf_ids->buf,inc,length<<1);//.wait();
338 }
339 }
340
341 if constexpr (MaxStencilSize >= 8) {
342 // B8
// 8-wide stencil: needs inc >= 4, consumes 3 bits of inc, uses len / 8 work-items.
343 if (inc >= 4 && ninc == 0) {
344 ninc = 3;
345 unsigned int nThreads = len >> ninc;
346
348 sched->get_queue(),
350 sham::MultiRef{buf_key, buf_values},
351 nThreads,
352 [=](u64 gid, Tkey *m, Tval *id) {
353 B::template order_kernel<8>(m, id, inc, length, gid);
354 });
355
356 // sort_kernel_B8(arg_eq,* buf_key->buf,*
357 // particles::buf_ids->buf,inc,length<<1);//.wait();
358 }
359 }
360
361 if constexpr (MaxStencilSize >= 4) {
362 // B4
// 4-wide stencil: needs inc >= 2, consumes 2 bits of inc, uses len / 4 work-items.
363 if (inc >= 2 && ninc == 0) {
364 ninc = 2;
365 unsigned int nThreads = len >> ninc;
366
368 sched->get_queue(),
370 sham::MultiRef{buf_key, buf_values},
371 nThreads,
372 [=](u64 gid, Tkey *m, Tval *id) {
373 B::template order_kernel<4>(m, id, inc, length, gid);
374 });
375 }
376 }
377
378 // B2
// Fallback when no wider stencil was selected: consumes 1 bit of inc, uses len / 2
// work-items.
379 if (ninc == 0) {
380 ninc = 1;
381 unsigned int nThreads = len >> ninc;
382
384 sched->get_queue(),
386 sham::MultiRef{buf_key, buf_values},
387 nThreads,
388 [=](u64 gid, Tkey *m, Tval *id) {
389 B::template order_kernel<2>(m, id, inc, length, gid);
390 });
391 }
392
// Drop the bits of inc that the launch above has handled.
393 inc >>= ninc;
394 }
395 }
396 }
397
// Explicit instantiations of sort_by_key_bitonic_updated_usm<Tkey, Tval, MaxStencilSize>:
//   (u32, u32) and (u64, u32) keys/values with MaxStencilSize 8, 16 and 32;
//   (f32, f32) and (f64, f64) keys/values with MaxStencilSize 16 and 32.
// NOTE(review): one parameter line is missing from each instantiation in this listing
// (original lines 400, 406, ..., 454) - presumably the key buffer; confirm against the
// real file.
398 template void sort_by_key_bitonic_updated_usm<u32, u32, 16>(
399 const sham::DeviceScheduler_ptr &sched,
401 sham::DeviceBuffer<u32> &buf_values,
402 u32 len);
403
404 template void sort_by_key_bitonic_updated_usm<u64, u32, 16>(
405 const sham::DeviceScheduler_ptr &sched,
407 sham::DeviceBuffer<u32> &buf_values,
408 u32 len);
409
410 template void sort_by_key_bitonic_updated_usm<u32, u32, 8>(
411 const sham::DeviceScheduler_ptr &sched,
413 sham::DeviceBuffer<u32> &buf_values,
414 u32 len);
415
416 template void sort_by_key_bitonic_updated_usm<u64, u32, 8>(
417 const sham::DeviceScheduler_ptr &sched,
419 sham::DeviceBuffer<u32> &buf_values,
420 u32 len);
421
422 template void sort_by_key_bitonic_updated_usm<u32, u32, 32>(
423 const sham::DeviceScheduler_ptr &sched,
425 sham::DeviceBuffer<u32> &buf_values,
426 u32 len);
427
428 template void sort_by_key_bitonic_updated_usm<u64, u32, 32>(
429 const sham::DeviceScheduler_ptr &sched,
431 sham::DeviceBuffer<u32> &buf_values,
432 u32 len);
433
434 template void sort_by_key_bitonic_updated_usm<f32, f32, 32>(
435 const sham::DeviceScheduler_ptr &sched,
437 sham::DeviceBuffer<f32> &buf_values,
438 u32 len);
439
440 template void sort_by_key_bitonic_updated_usm<f64, f64, 32>(
441 const sham::DeviceScheduler_ptr &sched,
443 sham::DeviceBuffer<f64> &buf_values,
444 u32 len);
445
446 template void sort_by_key_bitonic_updated_usm<f32, f32, 16>(
447 const sham::DeviceScheduler_ptr &sched,
449 sham::DeviceBuffer<f32> &buf_values,
450 u32 len);
451
452 template void sort_by_key_bitonic_updated_usm<f64, f64, 16>(
453 const sham::DeviceScheduler_ptr &sched,
455 sham::DeviceBuffer<f64> &buf_values,
456 u32 len);
457
458} // namespace shamalgs::algorithm::details
std::int8_t i8
8 bit integer
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
std::int16_t i16
16 bit integer
std::int32_t i32
32 bit integer
main include file for the shamalgs algorithms
A buffer allocated in USM (Unified Shared Memory)
void kernel_call_u64(sham::DeviceQueue &q, RefIn in, RefOut in_out, u64 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
u64 indexed variant of kernel_call
namespace to store algorithms implemented by shamalgs
constexpr bool is_pow_of_two(T v) noexcept
determine if v is a power of two and check if v==0 Source : https://graphics.stanford....
Definition integer.hpp:49
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
A class that references multiple buffers or similar objects.