// Holder for static ordering helpers parameterised on a key type (Tkey) and a
// value type (Tval): compare-and-swap primitives, fixed-size stencils and
// per-thread kernels.
// NOTE(review): this listing looks like an extracted excerpt in which some
// source lines are missing — confirm against the original file.
26 template<
class Tkey,
class Tval>
27 struct OrderingPrimitive {
// Compare-and-swap on a pair of keys (a, b) and their attached values (va, vb).
// The keys and values are exchanged together when `reverse` differs from `(a < b)`;
// otherwise all four references are rewritten with their previous contents.
29 inline static void _order(Tkey &a, Tkey &b, Tval &va, Tval &vb,
bool reverse) {
// XOR of the direction flag with the comparison result: true when exactly one holds.
30 bool swap = reverse ^ (a < b);
// NOTE(review): auxa/auxb/auxida/auxidb are not declared in the visible lines;
// presumably they are copies of a, b, va, vb taken before these writes — confirm.
// Both outputs are written unconditionally through ternary selects.
35 a = (swap) ? auxb : auxa;
36 b = (swap) ? auxa : auxb;
37 va = (swap) ? auxidb : auxida;
38 vb = (swap) ? auxida : auxidb;
// Array form of the compare-and-swap: operates on elements `a` and `b` of the
// key array `x` and the same positions of the value array `vx`.
// x[a]/x[b] (and vx[a]/vx[b]) are exchanged when `reverse` differs from `(x[a] < x[b])`.
41 inline static void _orderV(
42 Tkey *__restrict__ x, Tval *__restrict__ vx,
u32 a,
u32 b,
bool reverse) {
// True when exactly one of `reverse` and `x[a] < x[b]` holds.
43 bool swap = reverse ^ (x[a] < x[b]);
// NOTE(review): auxa/auxb/auxida/auxidb are not declared in the visible lines;
// presumably they hold x[a], x[b], vx[a], vx[b] read before these writes — confirm.
48 x[a] = (swap) ? auxb : auxa;
49 x[b] = (swap) ? auxa : auxb;
50 vx[a] = (swap) ? auxidb : auxida;
51 vx[b] = (swap) ? auxida : auxidb;
// Declaration of the stencil member template, selected by the compile-time
// `stencil_size`. It takes the key array `x`, the value array `vx`, a start
// offset `a` and the direction flag `reverse`.
// NOTE(review): the per-size bodies look like explicit specializations whose
// `template<>` lines are not visible in this listing — confirm.
54 template<u32 stencil_size>
55 static void order_stencil(Tkey *__restrict__ x, Tval *__restrict__ vx,
u32 a,
bool reverse);
// Size-2 stencil: a single compare-and-swap between positions a and a + 1.
58 inline void order_stencil<2>(
59 Tkey *__restrict__ x, Tval *__restrict__ vx,
u32 a,
bool reverse) {
60 _orderV(x, vx, a, a + 1, reverse);
// Size-4 stencil starting at offset a.
64 inline void order_stencil<4>(
65 Tkey *__restrict__ x, Tval *__restrict__ vx,
u32 a,
bool reverse) {
// First stage: compare-and-swap each element of the lower pair with the
// element two positions further (a+0 with a+2, a+1 with a+3).
67 for (
int i4 = 0; i4 < 2; i4++) {
68 _orderV(x, vx, a + i4, a + i4 + 2, reverse);
// Second stage: apply the size-2 stencil to each half, at a and at a + 2.
70 order_stencil<2>(x, vx, a, reverse);
71 order_stencil<2>(x, vx, a + 2, reverse);
// Size-8 stencil starting at offset a.
75 inline void order_stencil<8>(
76 Tkey *__restrict__ x, Tval *__restrict__ vx,
u32 a,
bool reverse) {
// First stage: compare-and-swap a+i8 with a+i8+4 for i8 = 0..3.
78 for (
int i8 = 0;
i8 < 4;
i8++) {
79 _orderV(x, vx, a +
i8, a +
i8 + 4, reverse);
// Second stage: apply the size-4 stencil to each half, at a and at a + 4.
81 order_stencil<4>(x, vx, a, reverse);
82 order_stencil<4>(x, vx, a + 4, reverse);
// Size-16 stencil starting at offset a.
86 inline void order_stencil<16>(
87 Tkey *__restrict__ x, Tval *__restrict__ vx,
u32 a,
bool reverse) {
// Compare-and-swap a+i16 with the element 8 positions further.
// NOTE(review): the loop declaring i16 is not visible in this listing;
// presumably i16 runs over [0, 8) — confirm.
90 _orderV(x, vx, a +
i16, a +
i16 + 8, reverse);
// Then apply the size-8 stencil to each half, at a and at a + 8.
92 order_stencil<8>(x, vx, a, reverse);
93 order_stencil<8>(x, vx, a + 8, reverse);
// Size-32 stencil starting at offset a.
97 inline void order_stencil<32>(
98 Tkey *__restrict__ x, Tval *__restrict__ vx,
u32 a,
bool reverse) {
// Compare-and-swap a+i32 with the element 16 positions further.
// NOTE(review): the loop declaring i32 is not visible in this listing;
// presumably i32 runs over [0, 16). Note that `i32` is also the name of a
// 32-bit integer alias used elsewhere in this file — confirm the shadowing is intended.
101 _orderV(x, vx, a +
i32, a +
i32 + 16, reverse);
// Then apply the size-16 stencil to each half, at a and at a + 16.
103 order_stencil<16>(x, vx, a, reverse);
104 order_stencil<16>(x, vx, a + 16, reverse);
// Declaration of the per-thread kernel member template, selected by the
// compile-time `stencil_size`.
//   m      : key array
//   id     : value array, permuted together with the keys
//   inc    : increment parameter (the bodies index with a derived `_inc` stride)
//   length : used by the bodies to build the direction mask (`length << 1`)
//   t      : thread / work-item index
107 template<u32 stencil_size>
108 static void order_kernel(
109 Tkey *__restrict__ m, Tval *__restrict__
id,
u32 inc,
u32 length,
i32 t);
// Per-thread kernel handling 32 strided elements: gathers 32 keys and values
// into local arrays, runs the size-32 stencil on them, and scatters them back.
112 inline void order_kernel<32>(
113 Tkey *__restrict__ m, Tval *__restrict__
id,
u32 inc,
u32 length,
i32 t) {
// Direction mask: twice the current `length`.
115 u32 _dir = length << 1U;
// NOTE(review): `_inc` is not declared in the visible lines; presumably it is
// derived from `inc` and is a power of two (so `_inc - 1` acts as a bit mask) — confirm.
118 int low = t & (_inc - 1);
// Base index: the part of t above `low` is scaled by 32, then `low` is added back.
119 int i = ((t - low) << 5) + low;
// Direction flag: true when the base index has no bit in common with _dir.
120 bool reverse = ((_dir & i) == 0);
// NOTE(review): `x` and `idx` are not declared in the visible lines; presumably
// local arrays of 32 Tkey / Tval — confirm.
// Gather keys at i, i + _inc, ..., i + 31 * _inc.
125 for (
int k = 0; k < 32; k++)
126 x[k] = m[k * _inc + i];
// Gather the matching values.
130 for (
int k = 0; k < 32; k++)
131 idx[k] =
id[k * _inc + i];
// Order the 32 gathered elements, starting at local offset 0.
134 order_stencil<32>(x, idx, 0, reverse);
// Scatter keys back to the same strided positions.
138 for (
int k = 0; k < 32; k++)
139 m[k * _inc + i] = x[k];
// Scatter values back.
141 for (
int k = 0; k < 32; k++)
142 id[k * _inc + i] = idx[k];
// Per-thread kernel handling 16 strided elements: gathers 16 keys and values
// into local arrays, runs the size-16 stencil on them, and scatters them back.
146 inline void order_kernel<16>(
147 Tkey *__restrict__ m, Tval *__restrict__
id,
u32 inc,
u32 length,
i32 t) {
// Direction mask: twice the current `length`.
150 u32 _dir = length << 1;
// NOTE(review): `_inc` is not declared in the visible lines; presumably it is
// derived from `inc` and is a power of two (so `_inc - 1` acts as a bit mask) — confirm.
153 int low = t & (_inc - 1);
// Base index: the part of t above `low` is scaled by 16, then `low` is added back.
154 int i = ((t - low) << 4) + low;
// Direction flag: true when the base index has no bit in common with _dir.
155 bool reverse = ((_dir & i) == 0);
// NOTE(review): `x` and `idx` are not declared in the visible lines; presumably
// local arrays of 16 Tkey / Tval — confirm.
// Gather keys at i, i + _inc, ..., i + 15 * _inc.
160 for (
int k = 0; k < 16; k++)
161 x[k] = m[k * _inc + i];
// Gather the matching values.
165 for (
int k = 0; k < 16; k++)
166 idx[k] =
id[k * _inc + i];
// Order the 16 gathered elements, starting at local offset 0.
169 order_stencil<16>(x, idx, 0, reverse);
// Scatter keys back to the same strided positions.
173 for (
int k = 0; k < 16; k++)
174 m[k * _inc + i] = x[k];
// Scatter values back.
176 for (
int k = 0; k < 16; k++)
177 id[k * _inc + i] = idx[k];
// Per-thread kernel handling 8 strided elements: gathers 8 keys and values
// into local arrays, runs the size-8 stencil on them, and scatters them back.
181 inline void order_kernel<8>(
182 Tkey *__restrict__ m, Tval *__restrict__
id,
u32 inc,
u32 length,
i32 t) {
// Direction mask: twice the current `length`.
184 u32 _dir = length << 1;
// NOTE(review): `_inc` is not declared in the visible lines; presumably it is
// derived from `inc` and is a power of two (so `_inc - 1` acts as a bit mask) — confirm.
187 int low = t & (_inc - 1);
// Base index: the part of t above `low` is scaled by 8, then `low` is added back.
188 int i = ((t - low) << 3) + low;
// Direction flag: true when the base index has no bit in common with _dir.
189 bool reverse = ((_dir & i) == 0);
// NOTE(review): `x` and `idx` are not declared in the visible lines; presumably
// local arrays of 8 Tkey / Tval — confirm.
// Gather keys at i, i + _inc, ..., i + 7 * _inc.
194 for (
int k = 0; k < 8; k++)
195 x[k] = m[k * _inc + i];
// Gather the matching values.
199 for (
int k = 0; k < 8; k++)
200 idx[k] =
id[k * _inc + i];
// Order the 8 gathered elements, starting at local offset 0.
203 order_stencil<8>(x, idx, 0, reverse);
// Scatter keys back to the same strided positions.
207 for (
int k = 0; k < 8; k++)
208 m[k * _inc + i] = x[k];
// Scatter values back.
210 for (
int k = 0; k < 8; k++)
211 id[k * _inc + i] = idx[k];
// Per-thread kernel handling 4 strided elements, held in scalar locals
// (x0..x3 / idx0..idx3) and ordered with four direct _order calls.
215 inline void order_kernel<4>(
216 Tkey *__restrict__ m, Tval *__restrict__
id,
u32 inc,
u32 length,
i32 t) {
// Direction mask: twice the current `length`.
218 u32 _dir = length << 1;
// NOTE(review): `_inc` is not declared in the visible lines; presumably it is
// derived from `inc` and is a power of two (so `_inc - 1` acts as a bit mask) — confirm.
221 int low = t & (_inc - 1);
// Base index: the part of t above `low` is scaled by 4, then `low` is added back.
222 int i = ((t - low) << 2) + low;
// Direction flag: true when the base index has no bit in common with _dir.
223 bool reverse = ((_dir & i) == 0);
// Load keys at stride _inc from the base index.
// NOTE(review): the load of x0 is not visible in this listing; presumably m[0 + i] — confirm.
227 Tkey x1 = m[_inc + i];
228 Tkey x2 = m[2 * _inc + i];
229 Tkey x3 = m[3 * _inc + i];
// Load the matching values.
231 Tval idx0 =
id[0 + i];
232 Tval idx1 =
id[_inc + i];
233 Tval idx2 =
id[2 * _inc + i];
234 Tval idx3 =
id[3 * _inc + i];
// Compare-and-swap pairs (0,2) and (1,3), then (0,1) and (2,3).
237 _order(x0, x2, idx0, idx2, reverse);
238 _order(x1, x3, idx1, idx3, reverse);
239 _order(x0, x1, idx0, idx1, reverse);
240 _order(x2, x3, idx2, idx3, reverse);
// Store keys back.
// NOTE(review): the stores of x0/x1 and idx0/idx1 are not visible in this
// listing; presumably they mirror the loads above — confirm.
245 m[2 * _inc + i] = x2;
246 m[3 * _inc + i] = x3;
// Store values back.
250 id[2 * _inc + i] = idx2;
251 id[3 * _inc + i] = idx3;
// Per-thread kernel handling a single pair of elements: one compare-and-swap
// between two positions that are _inc apart.
255 inline void order_kernel<2>(
256 Tkey *__restrict__ m, Tval *__restrict__
id,
u32 inc,
u32 length,
i32 t) {
// Direction mask: twice the current `length`.
258 u32 _dir = length << 1;
// NOTE(review): `_inc` is not declared in the visible lines; presumably it is
// derived from `inc` and is a power of two (so `_inc - 1` acts as a bit mask) — confirm.
260 int low = t & (_inc - 1);
// Base index: 2 * t minus `low`.
261 int i = (t << 1) - low;
// Direction flag: true when the base index has no bit in common with _dir.
262 bool reverse = ((_dir & i) == 0);
// Address of the second element of the pair.
// NOTE(review): `addr_1` is not declared in the visible lines; presumably it is
// the base index i — confirm.
265 u32 addr_2 = _inc + i;
// Load the two values.
// NOTE(review): the loads of x0/x1 are not visible; presumably m[addr_1] and m[addr_2] — confirm.
270 Tval idx0 =
id[addr_1];
271 Tval idx1 =
id[addr_2];
// Order the pair; the write-back to m/id is not visible in this listing.
274 _order(x0, x1, idx0, idx1, reverse);
// Sort-by-key driver. Template parameters: key type, value type, and
// MaxStencilSize, which selects at compile time which order_kernel<N>
// variants may be dispatched (32, 16, 8, 4; order_kernel<2> has no such guard).
// NOTE(review): large parts of the body (remaining parameters, the definitions
// of `len`, `inc` and `ninc`, the kernel launch calls and closing braces) are
// not visible in this listing — confirm details against the original file.
284 template<
class Tkey,
class Tval, u32 MaxStencilSize>
285 void sort_by_key_bitonic_updated_usm(
286 const sham::DeviceScheduler_ptr &sched,
// NOTE(review): tail of an error-reporting call whose beginning is not visible;
// presumably raised when the input length is not a power of two — confirm.
293 "this algorithm can only be used with length that are powers of two");
// Shorthand for the ordering helpers instantiated on the key/value types.
296 using B = OrderingPrimitive<Tkey, Tval>;
// Outer loop: `length` doubles each iteration while it stays below `len`.
298 for (
u32 length = 1; length < len; length <<= 1) {
// 32-element kernel: compiled in only when MaxStencilSize >= 32, taken when
// inc >= 16 and ninc == 0.
306 if constexpr (MaxStencilSize >= 32) {
307 if (inc >= 16 && ninc == 0) {
309 unsigned int nThreads = len >> ninc;
316 [=](
u64 gid, Tkey *m, Tval *
id) {
317 B::template order_kernel<32>(m, id, inc, length, gid);
// 16-element kernel: compiled in only when MaxStencilSize >= 16, taken when
// inc >= 8 and ninc == 0.
322 if constexpr (MaxStencilSize >= 16) {
323 if (inc >= 8 && ninc == 0) {
325 unsigned int nThreads = len >> ninc;
332 [=](
u64 gid, Tkey *m, Tval *
id) {
333 B::template order_kernel<16>(m, id, inc, length, gid);
// 8-element kernel: compiled in only when MaxStencilSize >= 8, taken when
// inc >= 4 and ninc == 0.
341 if constexpr (MaxStencilSize >= 8) {
343 if (inc >= 4 && ninc == 0) {
345 unsigned int nThreads = len >> ninc;
352 [=](
u64 gid, Tkey *m, Tval *
id) {
353 B::template order_kernel<8>(m, id, inc, length, gid);
// 4-element kernel: compiled in only when MaxStencilSize >= 4, taken when
// inc >= 2 and ninc == 0.
361 if constexpr (MaxStencilSize >= 4) {
363 if (inc >= 2 && ninc == 0) {
365 unsigned int nThreads = len >> ninc;
372 [=](
u64 gid, Tkey *m, Tval *
id) {
373 B::template order_kernel<4>(m, id, inc, length, gid);
// 2-element kernel: no `if constexpr` guard is visible for this case.
// NOTE(review): its runtime condition, if any, is not visible — confirm.
381 unsigned int nThreads = len >> ninc;
388 [=](
u64 gid, Tkey *m, Tval *
id) {
389 B::template order_kernel<2>(m, id, inc, length, gid);
// Explicit instantiations of sort_by_key_bitonic_updated_usm, listed as
// <key type, value type, MaxStencilSize>:
//   u32/u32 and u64/u32 with MaxStencilSize 16, 8 and 32;
//   f32/f32 and f64/f64 with MaxStencilSize 32 and 16.
// NOTE(review): the parameter lists are truncated after the first parameter in
// this listing.
398 template void sort_by_key_bitonic_updated_usm<u32, u32, 16>(
399 const sham::DeviceScheduler_ptr &sched,
404 template void sort_by_key_bitonic_updated_usm<u64, u32, 16>(
405 const sham::DeviceScheduler_ptr &sched,
410 template void sort_by_key_bitonic_updated_usm<u32, u32, 8>(
411 const sham::DeviceScheduler_ptr &sched,
416 template void sort_by_key_bitonic_updated_usm<u64, u32, 8>(
417 const sham::DeviceScheduler_ptr &sched,
422 template void sort_by_key_bitonic_updated_usm<u32, u32, 32>(
423 const sham::DeviceScheduler_ptr &sched,
428 template void sort_by_key_bitonic_updated_usm<u64, u32, 32>(
429 const sham::DeviceScheduler_ptr &sched,
434 template void sort_by_key_bitonic_updated_usm<f32, f32, 32>(
435 const sham::DeviceScheduler_ptr &sched,
440 template void sort_by_key_bitonic_updated_usm<f64, f64, 32>(
441 const sham::DeviceScheduler_ptr &sched,
446 template void sort_by_key_bitonic_updated_usm<f32, f32, 16>(
447 const sham::DeviceScheduler_ptr &sched,
452 template void sort_by_key_bitonic_updated_usm<f64, f64, 16>(
453 const sham::DeviceScheduler_ptr &sched,
std::int8_t i8
8-bit integer
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
std::int16_t i16
16-bit integer
std::int32_t i32
32-bit integer
main include file for the shamalgs algorithms
A buffer allocated in USM (Unified Shared Memory)
void kernel_call_u64(sham::DeviceQueue &q, RefIn in, RefOut in_out, u64 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
u64 indexed variant of kernel_call
namespace to store algorithms implemented by shamalgs
constexpr bool is_pow_of_two(T v) noexcept
Determine whether v is a power of two, also checking for v == 0. Source: https://graphics.stanford....
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
A class that references multiple buffers or similar objects.