25 inline sycl::buffer<T> flatten_buffer(sycl::queue &q, sycl::buffer<T> &buf_in,
u64 len) {
26 sycl::buffer<T> ret(len);
28 q.submit([=, &buf_in](sycl::handler &cgh) {
29 sycl::accessor acc_in{buf_in, cgh, sycl::read_only};
30 sycl::accessor acc_out{ret, cgh, sycl::write_only, sycl::no_init};
31 cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
32 acc_out[id] = acc_in[id];
37 template<
class T,
int n>
38 inline sycl::buffer<T> flatten_buffer(
39 sycl::queue &q, sycl::buffer<sycl::vec<T, n>> &buf_in,
u64 len) {
40 sycl::buffer<T> ret(len * n);
42 q.submit([=, &buf_in, &ret](sycl::handler &cgh) {
43 sycl::accessor acc_in{buf_in, cgh, sycl::read_only};
44 sycl::accessor acc_out{ret, cgh, sycl::write_only, sycl::no_init};
46 cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
47 u32 idx =
id.get_linear_id() * n;
49 if constexpr (n == 2) {
50 acc_out[idx] = acc_in[id].x();
51 acc_out[idx + 1] = acc_in[id].y();
54 if constexpr (n == 3) {
55 acc_out[idx] = acc_in[id].x();
56 acc_out[idx + 1] = acc_in[id].y();
57 acc_out[idx + 2] = acc_in[id].z();
60 if constexpr (n == 4) {
61 acc_out[idx] = acc_in[id].x();
62 acc_out[idx + 1] = acc_in[id].y();
63 acc_out[idx + 2] = acc_in[id].z();
64 acc_out[idx + 3] = acc_in[id].w();
67 if constexpr (n == 8) {
68 acc_out[idx] = acc_in[id].s0();
69 acc_out[idx + 1] = acc_in[id].s1();
70 acc_out[idx + 2] = acc_in[id].s2();
71 acc_out[idx + 3] = acc_in[id].s3();
72 acc_out[idx + 4] = acc_in[id].s4();
73 acc_out[idx + 5] = acc_in[id].s5();
74 acc_out[idx + 6] = acc_in[id].s6();
75 acc_out[idx + 7] = acc_in[id].s7();
78 if constexpr (n == 16) {
79 acc_out[idx] = acc_in[id].s0();
80 acc_out[idx + 1] = acc_in[id].s1();
81 acc_out[idx + 2] = acc_in[id].s2();
82 acc_out[idx + 3] = acc_in[id].s3();
83 acc_out[idx + 4] = acc_in[id].s4();
84 acc_out[idx + 5] = acc_in[id].s5();
85 acc_out[idx + 6] = acc_in[id].s6();
86 acc_out[idx + 7] = acc_in[id].s7();
87 acc_out[idx + 8] = acc_in[id].s8();
88 acc_out[idx + 9] = acc_in[id].s9();
89 acc_out[idx + 10] = acc_in[id].sA();
90 acc_out[idx + 11] = acc_in[id].sB();
91 acc_out[idx + 12] = acc_in[id].sC();
92 acc_out[idx + 13] = acc_in[id].sD();
93 acc_out[idx + 14] = acc_in[id].sE();
94 acc_out[idx + 15] = acc_in[id].sF();
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
memory manipulation algorithms