Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
numeric.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
19#include "shambase/assert.hpp"
20#include "shambase/integer.hpp"
30#include <utility>
31
32namespace shamalgs::numeric {
33
34 template<class T>
35 sycl::buffer<T> scan_exclusive(sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {
36#ifdef __MACH__ // decoupled lookback perf on mac os is awful
37 return details::exclusive_sum_fallback(q, buf1, len);
38#else
39 #ifdef SYCL2020_FEATURE_GROUP_REDUCTION
40 return details::exclusive_sum_atomic_decoupled_v5<T, 512>(q, buf1, len);
41 #else
42 return details::exclusive_sum_fallback(q, buf1, len);
43 #endif
44#endif
45 }
46
// Exclusive-sum variant operating on a sham::DeviceBuffer through a device scheduler.
// NOTE(review): the declarator line (return type and function name) is absent from this
// listing. Code further down calls scan_exclusive(sched, <DeviceBuffer>, n) and stores the
// result in a sham::DeviceBuffer<u32>, so this is presumably that overload - confirm
// against the header.
//
// Backend selection: fallback on macOS (__MACH__) or when
// SYCL2020_FEATURE_GROUP_REDUCTION is undefined; otherwise the decoupled look-back USM
// kernel instantiated with 512.
47 template<class T>
49 sham::DeviceScheduler_ptr sched, sham::DeviceBuffer<T> &buf1, u32 len) {
50#ifdef __MACH__ // decoupled lookback perf on mac os is awful
51 return details::exclusive_sum_fallback_usm(sched, buf1, len);
52#else
53 #ifdef SYCL2020_FEATURE_GROUP_REDUCTION
54 return details::exclusive_sum_atomic_decoupled_v5_usm<T, 512>(sched, buf1, len);
55 #else
56 return details::exclusive_sum_fallback_usm(sched, buf1, len);
57 #endif
58#endif
59 }
60
61 template<class T>
62 sycl::buffer<T> scan_inclusive(sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {
63 return details::inclusive_sum_fallback(q, buf1, len);
64 }
65
66 template<class T>
67 void scan_exclusive_in_place(sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {
68 buf1 = details::exclusive_sum_atomic_decoupled_v5<T, 256>(q, buf1, len);
69 }
70
71 template<class T>
72 void scan_inclusive_in_place(sycl::queue &q, sycl::buffer<T> &buf1, u32 len) {
73 buf1 = details::inclusive_sum_fallback(q, buf1, len);
74 }
75
76 template sycl::buffer<u32> scan_exclusive(sycl::queue &q, sycl::buffer<u32> &buf1, u32 len);
78 sham::DeviceScheduler_ptr sched, sham::DeviceBuffer<u32> &buf1, u32 len);
79 template sycl::buffer<u32> scan_inclusive(sycl::queue &q, sycl::buffer<u32> &buf1, u32 len);
80
81 template void scan_exclusive_in_place(sycl::queue &q, sycl::buffer<u32> &buf1, u32 len);
82 template void scan_inclusive_in_place(sycl::queue &q, sycl::buffer<u32> &buf1, u32 len);
83
84 std::tuple<std::optional<sycl::buffer<u32>>, u32> stream_compact(
85 sycl::queue &q, sycl::buffer<u32> &buf_flags, u32 len) {
86 return details::stream_compact_excl_scan(q, buf_flags, len);
87 };
88
// Scheduler / sham::DeviceBuffer variant of stream compaction: forwards its arguments to
// details::stream_compact_excl_scan and returns the result.
// NOTE(review): the declarator line (return type and function name) is absent from this
// listing. A later call site stores stream_compact(sched, key_filter, len) in a
// sham::DeviceBuffer<u32>, so that is presumably the return type - confirm against the
// header.
90 const sham::DeviceScheduler_ptr &sched, sham::DeviceBuffer<u32> &buf_flags, u32 len) {
91 return details::stream_compact_excl_scan(sched, buf_flags, len);
92 }
93
// Histogram of `values` over the bins delimited by `bin_edges`.
// NOTE(review): the declarator line, the allocation of `counts` and the kernel-launch call
// line are absent from this listing; comments about them below are hedged accordingly.
//
// Parameters (as used by the visible code):
//   sched     - scheduler whose queue runs the kernel
//   bin_edges - bin boundaries; asserted to hold nbins + 1 entries
//   nbins     - number of bins; asserted to be > 1
//   values    - values to bin; entry i is read by work item i
//   len       - work-item count passed to the launch; len == 0 returns early
// Returns `counts` after it was filled with 0 and incremented by the kernel.
// A value v is counted only if bin_edges[0] <= v < bin_edges[nbins]; the bin index is then
// found by bisection on bin_edges.
94 template<class Tret, class T>
96 const sham::DeviceScheduler_ptr &sched,
97 const sham::DeviceBuffer<T> &bin_edges,
98 u64 nbins,
99 const sham::DeviceBuffer<T> &values,
100 u32 len) {
101
102 SHAM_ASSERT(nbins > 1); // at least a sup and a inf
103 SHAM_ASSERT(bin_edges.get_size() == nbins + 1);
104
// Start from all-zero counts so the early return below yields an empty histogram.
106 counts.fill(0);
107
108 if (len == 0) {
109 return counts;
110 }
111
112 auto &q = shambase::get_check_ref(sched).get_queue();
113
// Arguments of the kernel launch (the call line itself is missing from this listing,
// presumably sham::kernel_call): inputs {values, bin_edges}, output {counts}, len items.
115 q,
116 sham::MultiRef{values, bin_edges},
117 sham::MultiRef{counts},
118 len,
119 [nbins](
120 u32 i,
121 const T *__restrict values,
122 const T *__restrict bin_edges,
123 Tret *__restrict counts) {
124 // Only count values within [bin_edges[0], bin_edges[nbins])
// NOTE(review): a NaN value makes both comparisons false, so it is NOT rejected
// here. In the bisection below it always takes the else-branch and ends with
// start_range == nbins. Whether counts[nbins] is in bounds depends on the counts
// allocation, which is not visible here - confirm, or reject NaN explicitly.
125 if (values[i] < bin_edges[0] || values[i] >= bin_edges[nbins]) {
126 return;
127 }
128
// Bisection over edge indices. For in-range values the loop keeps
// bin_edges[start_range] <= values[i] and, once end_range has been lowered,
// values[i] < bin_edges[end_range].
// end_range starts one past the last edge index; it is never dereferenced because
// mid_range < end_range, so mid_range <= nbins. nbins (u64) is narrowed to u32 here.
129 u32 start_range = 0;
130 u32 end_range = nbins + 1;
131
132 while (end_range - start_range > 1) {
133 u32 mid_range = (start_range + end_range) / 2;
134
135 if (values[i] < bin_edges[mid_range]) { // edge mid_range is a strict upper bound
136 end_range = mid_range;
137 } else { // edge mid_range is a lower bound
138 start_range = mid_range;
139 }
140 }
141
142 SHAM_ASSERT(end_range == start_range + 1);
143
// Atomic increment (relaxed order, device scope, global address space) of the
// counter of the selected bin, whose index is start_range.
144 sycl::atomic_ref<
145 Tret,
146 sycl::memory_order_relaxed,
147 sycl::memory_scope_device,
148 sycl::access::address_space::global_space>
149 cnt(counts[start_range]);
150
151 cnt++;
152 });
153
154 return counts;
155 }
156
157 template sham::DeviceBuffer<u64> device_histogram<u64, f64>(
158 const sham::DeviceScheduler_ptr &sched,
159 const sham::DeviceBuffer<f64> &bin_edges,
160 u64 nbins,
161 const sham::DeviceBuffer<f64> &values,
162 u32 len);
163 template sham::DeviceBuffer<u64> device_histogram<u64, f32>(
164 const sham::DeviceScheduler_ptr &sched,
165 const sham::DeviceBuffer<f32> &bin_edges,
166 u64 nbins,
167 const sham::DeviceBuffer<f32> &values,
168 u32 len);
169 template sham::DeviceBuffer<u32> device_histogram<u32, f64>(
170 const sham::DeviceScheduler_ptr &sched,
171 const sham::DeviceBuffer<f64> &bin_edges,
172 u64 nbins,
173 const sham::DeviceBuffer<f64> &values,
174 u32 len);
175 template sham::DeviceBuffer<u32> device_histogram<u32, f32>(
176 const sham::DeviceScheduler_ptr &sched,
177 const sham::DeviceBuffer<f32> &bin_edges,
178 u64 nbins,
179 const sham::DeviceBuffer<f32> &values,
180 u32 len);
181
// Prepare binned data for per-bin computation.
// NOTE(review): the declarator line and the kernel-launch call lines are absent from this
// listing (presumably sham::kernel_call); only their argument lists are visible.
//
// Pipeline, as implemented below:
//   1. flag every key within [bin_edges[0], bin_edges[nbins]) and stream-compact the flags
//      into a list of valid indices;
//   2. gather the matching keys and values into dense buffers;
//   3. histogram the valid keys, append a trailing 0 and exclusive-scan the counts to get
//      per-bin offsets (bin i spans offsets[i] .. offsets[i+1]);
//   4. pad keys/values to a power-of-two length with get_max<T>() and sort by key.
// Returns {valid_values, offsets_bins}.
// NOTE(review): valid_values is not shrunk back after step 4, so when padding happened the
// returned buffer holds pow2_len_key entries, the extra ones being get_max<T>() padding
// (assuming the sort is ascending - confirm).
// NOTE(review): NaN keys pass the range filter (both comparisons are false) and are
// forwarded to device_histogram - confirm this is handled there.
182 template<class T>
184 const sham::DeviceScheduler_ptr &sched,
185 const sham::DeviceBuffer<T> &bin_edges,
186 u64 nbins,
187 const sham::DeviceBuffer<T> &values, // ie f(r)
188 const sham::DeviceBuffer<T> &keys, // ie r
189 u32 len) { // ie return <f(r)>_r
190
191 auto &q = shambase::get_check_ref(sched).get_queue();
192
// Step 1: returns the indices of the keys lying inside the binned range
// (an empty buffer when len == 0).
193 auto value_filter = [&]() {
194 if (len > 0) {
195
196 // filter values
// key_filter is sized from keys.get_size(); the launch below is given len items.
197 sham::DeviceBuffer<u32> key_filter(keys.get_size(), sched);
198
200 q,
201 sham::MultiRef{keys, bin_edges},
202 sham::MultiRef{key_filter},
203 len,
204 [nbins](
205 u32 i,
206 const T *__restrict keys,
207 const T *__restrict bin_edges,
208 u32 *__restrict key_filter) {
209 // Only count keys within [bin_edges[0], bin_edges[nbins])
210 if (keys[i] < bin_edges[0] || keys[i] >= bin_edges[nbins]) {
211 key_filter[i] = 0;
212 } else {
213 key_filter[i] = 1;
214 }
215 });
216
217 // compact
218 sham::DeviceBuffer<u32> valid_key_idxs = stream_compact(sched, key_filter, len);
219
220 return valid_key_idxs;
221 } else {
222 return sham::DeviceBuffer<u32>(0, sched);
223 }
224 };
225
226 sham::DeviceBuffer<u32> valid_key_idxs = value_filter();
227
// get_size() returns size_t; it is narrowed to u32 here.
228 u32 valid_key_count = valid_key_idxs.get_size();
229
230 // make the buffer with all the valid keys
231 sham::DeviceBuffer<T> valid_keys(valid_key_count, sched);
232 sham::DeviceBuffer<T> valid_values(valid_key_count, sched);
233
// Step 2: gather keys[idx] / values[idx] for every valid index into dense buffers.
234 if (valid_key_count > 0) {
236 q,
237 sham::MultiRef{keys, values, valid_key_idxs},
238 sham::MultiRef{valid_keys, valid_values},
239 valid_key_count,
240 [](u32 i,
241 const T *__restrict keys,
242 const T *__restrict values,
243 const u32 *__restrict valid_keys_idxs,
244 T *__restrict valid_keys,
245 T *__restrict valid_values) {
246 u32 src_key = valid_keys_idxs[i];
247 valid_keys[i] = keys[src_key];
248 valid_values[i] = values[src_key];
249 });
250 }
251
252 // histogram standard
253 sham::DeviceBuffer<u32> bin_counts
254 = device_histogram<u32>(sched, bin_edges, nbins, valid_keys, valid_key_count);
255
// Append one extra entry set to 0 so that the exclusive scan below also produces the
// end offset of the last bin (checked against valid_key_count by the assert).
256 bin_counts.expand(1);
257 bin_counts.set_val_at_idx(bin_counts.get_size() - 1, 0);
258
259 // exclusive scan
260 // bin_ids[i] starts at offset[i] and ends at offset[i+1]
261 sham::DeviceBuffer<u32> offsets_bins
262 = scan_exclusive(sched, bin_counts, bin_counts.get_size());
263
264 SHAM_ASSERT(offsets_bins.get_val_at_idx(offsets_bins.get_size() - 1) == valid_key_count);
265
266 if (valid_key_count > 0) {
267 // the sort needs a power-of-two length
268 u32 pow2_len_key = shambase::roundup_pow2(valid_key_count);
269 {
270 if (pow2_len_key > valid_key_count) {
271 valid_keys.resize(pow2_len_key);
272 valid_values.resize(pow2_len_key);
273
// Fill the padding entries [valid_key_count, pow2_len_key) of both buffers
// with get_max<T>().
275 q,
277 sham::MultiRef{valid_keys, valid_values},
278 pow2_len_key - valid_key_count,
279 [offset_start = valid_key_count](
280 u32 i, T *__restrict valid_keys, T *__restrict valid_values) {
281 u32 key_id = offset_start + i;
282 valid_keys[key_id] = shambase::get_max<T>();
283 valid_values[key_id] = shambase::get_max<T>();
284 });
285 }
286 }
287
288 // Pitfall: the buffers may have been resized to a power of two above, so the sort must
289 // be given the padded length pow2_len_key, not the pre-padding valid_key_count.
290 shamalgs::algorithm::sort_by_key(sched, valid_keys, valid_values, pow2_len_key);
291 }
292
293 return {std::move(valid_values), std::move(offsets_bins)};
294 }
295
297 const sham::DeviceScheduler_ptr &sched,
298 const sham::DeviceBuffer<f64> &bin_edges,
299 u64 nbins,
300 const sham::DeviceBuffer<f64> &values,
301 const sham::DeviceBuffer<f64> &keys,
302 u32 len);
304 const sham::DeviceScheduler_ptr &sched,
305 const sham::DeviceBuffer<f32> &bin_edges,
306 u64 nbins,
307 const sham::DeviceBuffer<f32> &values,
308 const sham::DeviceBuffer<f32> &keys,
309 u32 len);
310
311} // namespace shamalgs::numeric
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
Shamrock assertion utility.
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
Definition assert.hpp:67
A buffer allocated in USM (Unified Shared Memory)
DeviceQueue & get_queue() const
Gets the DeviceQueue associated with the held allocation.
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
void fill(T value, std::array< size_t, 2 > idx_range)
Fill a subpart of the buffer with a given value.
T get_val_at_idx(size_t idx) const
Get the value at a given index in the buffer.
size_t get_size() const
Gets the number of elements in the buffer.
void expand(u32 add_sz)
Expand the buffer by add_sz elements.
main include file for the shamalgs algorithms
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
void sort_by_key(sycl::queue &q, sycl::buffer< Tkey > &buf_key, sycl::buffer< Tval > &buf_values, u32 len)
Sort the buffer according to the key order.
Definition algorithm.hpp:41
namespace containing the numeric algorithms of shamalgs
BinnedCompute< T > binned_init_compute(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &bin_edges, u64 nbins, const sham::DeviceBuffer< T > &values, const sham::DeviceBuffer< T > &keys, u32 len)
Prepare binned data for per-bin computation.
Definition numeric.cpp:183
sham::DeviceBuffer< Tret > device_histogram(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &bin_edges, u64 nbins, const sham::DeviceBuffer< T > &values, u32 len)
Compute the histogram of values between bin_edges.
Definition numeric.cpp:95
sycl::buffer< T > scan_exclusive(sycl::queue &q, sycl::buffer< T > &buf1, u32 len)
Computes the exclusive sum of elements in a SYCL buffer.
Definition numeric.cpp:35
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
Definition numeric.cpp:84
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
constexpr T roundup_pow2(T v) noexcept
Round up to the next power of two. Source: https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
Definition integer.hpp:92
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
A class that references multiple buffers or similar objects.
Structure holding the result of binning values for further computation.
Definition numeric.hpp:216