Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
flatten.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
22#include <type_traits>
23#include <stdexcept>
24#include <utility>
25
26namespace shamalgs::primitives {
27
// flatten_buffer: flatten a buffer whose elements are of vector type Tvec into a new
// buffer of its scalar component type Tscal, storing the components of each element
// contiguously: [e0[0], e0[1], ..., e1[0], e1[1], ...].
//
// Template parameters:
//   Tvec   - element type of the input buffer. The visible branches handle a plain
//            scalar, sycl::vec<Tscal, 2> and sycl::vec<Tscal, 3>.
//   target - USM kind of the buffers. Only sham::USMKindTarget::device is handled in
//            the visible code.
// Parameter:
//   buffer - input buffer. It is only read (get_read_access); the result is a new
//            buffer allocated on the same device scheduler.
// Returns:
//   A DeviceBuffer<Tscal, target> holding get_size() * N scalars for an N-component Tvec,
//   or a plain copy when Tvec is already a scalar.
// NOTE(review): listing lines 39-40 (the function signature) are missing from this
// extract. The generated summary gives it as
// flatten_buffer(const sham::DeviceBuffer<Tvec, target> &buffer).
38 template<class Tvec, sham::USMKindTarget target>
41
// Scalar component type of Tvec (e.g. f64 for a sycl::vec<f64, 3>).
42 using Tscal = typename shambase::VectorProperties<Tvec>::component_type;
// The output buffer is allocated on the same scheduler as the input.
43 auto &sched = buffer.get_dev_scheduler_ptr();
44
// Only device-memory buffers are handled here. The alternative branch body (listing
// line 99) is not visible in this extract; it presumably reports the case as
// unimplemented. Confirm.
45 if constexpr (target == sham::USMKindTarget::device) {
46
// Tvec is already a scalar, so flattening is a plain copy.
47 if constexpr (std::is_same_v<Tvec, Tscal>) {
48 return buffer.copy();
49 } else if constexpr (std::is_same_v<Tvec, sycl::vec<Tscal, 2>>) {
50
// Two scalars per input element.
51 sham::DeviceBuffer<Tscal, target> ret(buffer.get_size() * 2, sched);
52
// The access calls below fill depends_list, which is then passed to submit so the
// kernel is ordered after those dependencies.
53 sham::EventList depends_list;
54 const Tvec *ptr_src = buffer.get_read_access(depends_list);
55 Tscal *ptr_dest = ret.get_write_access(depends_list);
56
// One work-item per input element writes its two components to consecutive slots.
57 sycl::event e = buffer.get_dev_scheduler().get_queue().submit(
58 depends_list, [&](sycl::handler &cgh) {
59 cgh.parallel_for(buffer.get_size(), [=](sycl::id<1> gid) {
60 Tvec tmp = ptr_src[gid];
61 ptr_dest[gid * 2 + 0] = tmp[0];
62 ptr_dest[gid * 2 + 1] = tmp[1];
63 });
64 });
65
// Register the kernel event with the source buffer.
// NOTE(review): listing line 66 is not visible here. unflatten_buffer calls
// ret.complete_event_state(e) in addition to buffer.complete_event_state(e). Confirm
// that the write access taken on ret is also completed in this branch.
67 buffer.complete_event_state(e);
68
69 return ret;
70
71 } else if constexpr (std::is_same_v<Tvec, sycl::vec<Tscal, 3>>) {
72
// Three scalars per input element.
73 sham::DeviceBuffer<Tscal, target> ret(buffer.get_size() * 3, sched);
74
// Same dependency handling as the 2-component case.
75 sham::EventList depends_list;
76 const Tvec *ptr_src = buffer.get_read_access(depends_list);
77 Tscal *ptr_dest = ret.get_write_access(depends_list);
78
// One work-item per input element writes its three components to consecutive slots.
79 sycl::event e = buffer.get_dev_scheduler().get_queue().submit(
80 depends_list, [&](sycl::handler &cgh) {
81 cgh.parallel_for(buffer.get_size(), [=](sycl::id<1> gid) {
82 Tvec tmp = ptr_src[gid];
83 ptr_dest[gid * 3 + 0] = tmp[0];
84 ptr_dest[gid * 3 + 1] = tmp[1];
85 ptr_dest[gid * 3 + 2] = tmp[2];
86 });
87 });
88
// Register the kernel event with the source buffer.
// NOTE(review): listing line 89 is not visible here. As in the 2-component case, confirm
// that ret's write access is also completed (unflatten_buffer does so explicitly).
90 buffer.complete_event_state(e);
91 return ret;
92
93 } else {
// Other vector widths. The body (listing line 94) is not visible in this extract;
// presumably it reports the case as unimplemented. Confirm.
95 }
96
97 } else {
// Non-device USM targets. The body (listing line 99) is not visible in this extract.
99 }
100 }
101
// unflatten_buffer: inverse of the flattening layout. It rebuilds a buffer of vector
// type Tvec from a buffer of scalars laid out as [e0[0], e0[1], ..., e1[0], e1[1], ...].
//
// Template parameters:
//   Tvec   - element type of the returned buffer. The visible branches handle a plain
//            scalar, sycl::vec<Tscal, 2> and sycl::vec<Tscal, 3>.
//   target - USM kind of the buffers. Only sham::USMKindTarget::device is handled in
//            the visible code.
// Parameter:
//   buffer - flattened scalar input. It is only read (get_read_access); its size must
//            be a multiple of the number of components of Tvec.
// Returns:
//   A DeviceBuffer<Tvec, target> holding get_size() / N elements for an N-component
//   Tvec, or a plain copy when Tvec is already a scalar.
// Errors:
//   A size that is not a multiple of N raises an error. The call itself sits on a
//   listing line not visible here; presumably it is shambase::throw_with_loc. Confirm.
// NOTE(review): listing line 116 (return type and function name) is missing from this
// extract.
115 template<class Tvec, sham::USMKindTarget target>
117 const sham::DeviceBuffer<typename shambase::VectorProperties<Tvec>::component_type, target>
118 &buffer) {
119
// Scalar component type of Tvec.
120 using Tscal = typename shambase::VectorProperties<Tvec>::component_type;
// The output buffer is allocated on the same scheduler as the input.
121 auto &sched = buffer.get_dev_scheduler_ptr();
122
// Only device-memory buffers are handled here. The alternative branch body (listing
// line 182) is not visible in this extract.
123 if constexpr (target == sham::USMKindTarget::device) {
124
// Tvec is already a scalar, so unflattening is a plain copy.
125 if constexpr (std::is_same_v<Tscal, Tvec>) {
126 return buffer.copy();
127 } else if constexpr (std::is_same_v<Tvec, sycl::vec<Tscal, 2>>) {
128
// Reject inputs that cannot be split into whole 2-component elements.
129 if (buffer.get_size() % 2 != 0) {
131 "The buffer must have an even number of elements");
132 }
133
// One output element per pair of input scalars.
134 sham::DeviceBuffer<Tvec, target> ret(buffer.get_size() / 2, sched);
135
// The access calls below fill depends_list, which is then passed to submit so the
// kernel is ordered after those dependencies.
136 sham::EventList depends_list;
137 const Tscal *ptr_src = buffer.get_read_access(depends_list);
138 Tvec *ptr_dest = ret.get_write_access(depends_list);
139
// One work-item per output element gathers two consecutive scalars.
140 sycl::event e = buffer.get_dev_scheduler().get_queue().submit(
141 depends_list, [&](sycl::handler &cgh) {
142 cgh.parallel_for(buffer.get_size() / 2, [=](sycl::id<1> gid) {
143 ptr_dest[gid] = Tvec{ptr_src[gid * 2 + 0], ptr_src[gid * 2 + 1]};
144 });
145 });
146
// Register the kernel event with both the destination and the source buffer.
147 ret.complete_event_state(e);
148 buffer.complete_event_state(e);
149
150 return ret;
151
152 } else if constexpr (std::is_same_v<Tvec, sycl::vec<Tscal, 3>>) {
153
// Reject inputs that cannot be split into whole 3-component elements.
154 if (buffer.get_size() % 3 != 0) {
156 "The buffer must have a multiple of 3 elements");
157 }
158
// One output element per triple of input scalars.
159 sham::DeviceBuffer<Tvec, target> ret(buffer.get_size() / 3, sched);
160
// Same dependency handling as the 2-component case.
161 sham::EventList depends_list;
162 const Tscal *ptr_src = buffer.get_read_access(depends_list);
163 Tvec *ptr_dest = ret.get_write_access(depends_list);
164
// One work-item per output element gathers three consecutive scalars.
165 sycl::event e = buffer.get_dev_scheduler().get_queue().submit(
166 depends_list, [&](sycl::handler &cgh) {
167 cgh.parallel_for(buffer.get_size() / 3, [=](sycl::id<1> gid) {
168 ptr_dest[gid] = Tvec{
169 ptr_src[gid * 3 + 0], ptr_src[gid * 3 + 1], ptr_src[gid * 3 + 2]};
170 });
171 });
172
// Register the kernel event with both the destination and the source buffer.
173 ret.complete_event_state(e);
174 buffer.complete_event_state(e);
175
176 return ret;
177 } else {
// Other vector widths. The body (listing line 178) is not visible in this extract;
// presumably it reports the case as unimplemented. Confirm.
179 }
180
181 } else {
// Non-device USM targets. The body (listing line 182) is not visible in this extract.
183 }
184 }
185
186} // namespace shamalgs::primitives
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
std::shared_ptr< DeviceScheduler > & get_dev_scheduler_ptr()
Gets the Device scheduler pointer corresponding to the held allocation.
size_t get_size() const
Gets the number of elements in the buffer.
DeviceScheduler & get_dev_scheduler() const
Gets the Device scheduler corresponding to the held allocation.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
DeviceBuffer< T, target > copy() const
Copy the current buffer.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
This header file contains utility functions related to exception handling in the code.
@ device
Device memory.
Namespace for primitive algorithms (e.g. sort, scan, reductions, ...)
sham::DeviceBuffer< Tvec, target > unflatten_buffer(const sham::DeviceBuffer< typename shambase::VectorProperties< Tvec >::component_type, target > &buffer)
Unflatten a buffer that contains a flattened vector.
Definition flatten.hpp:116
sham::DeviceBuffer< typename shambase::VectorProperties< Tvec >::component_type, target > flatten_buffer(const sham::DeviceBuffer< Tvec, target > &buffer)
Flatten a buffer of vector type into a buffer of scalar type.
Definition flatten.hpp:40
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.