37 return details::exclusive_sum_fallback(q, buf1, len);
39 #ifdef SYCL2020_FEATURE_GROUP_REDUCTION
40 return details::exclusive_sum_atomic_decoupled_v5<T, 512>(q, buf1, len);
42 return details::exclusive_sum_fallback(q, buf1, len);
51 return details::exclusive_sum_fallback_usm(sched, buf1, len);
53 #ifdef SYCL2020_FEATURE_GROUP_REDUCTION
54 return details::exclusive_sum_atomic_decoupled_v5_usm<T, 512>(sched, buf1, len);
56 return details::exclusive_sum_fallback_usm(sched, buf1, len);
62 sycl::buffer<T> scan_inclusive(sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
63 return details::inclusive_sum_fallback(q, buf1, len);
67 void scan_exclusive_in_place(sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
68 buf1 = details::exclusive_sum_atomic_decoupled_v5<T, 256>(q, buf1, len);
72 void scan_inclusive_in_place(sycl::queue &q, sycl::buffer<T> &buf1,
u32 len) {
73 buf1 = details::inclusive_sum_fallback(q, buf1, len);
76 template sycl::buffer<u32>
scan_exclusive(sycl::queue &q, sycl::buffer<u32> &buf1,
u32 len);
79 template sycl::buffer<u32> scan_inclusive(sycl::queue &q, sycl::buffer<u32> &buf1,
u32 len);
81 template void scan_exclusive_in_place(sycl::queue &q, sycl::buffer<u32> &buf1,
u32 len);
82 template void scan_inclusive_in_place(sycl::queue &q, sycl::buffer<u32> &buf1,
u32 len);
85 sycl::queue &q, sycl::buffer<u32> &buf_flags,
u32 len) {
86 return details::stream_compact_excl_scan(q, buf_flags, len);
91 return details::stream_compact_excl_scan(sched, buf_flags, len);
94 template<
class Tret,
class T>
96 const sham::DeviceScheduler_ptr &sched,
121 const T *__restrict values,
122 const T *__restrict bin_edges,
123 Tret *__restrict counts) {
125 if (values[i] < bin_edges[0] || values[i] >= bin_edges[nbins]) {
130 u32 end_range = nbins + 1;
132 while (end_range - start_range > 1) {
133 u32 mid_range = (start_range + end_range) / 2;
135 if (values[i] < bin_edges[mid_range]) {
136 end_range = mid_range;
138 start_range = mid_range;
146 sycl::memory_order_relaxed,
147 sycl::memory_scope_device,
148 sycl::access::address_space::global_space>
149 cnt(counts[start_range]);
158 const sham::DeviceScheduler_ptr &sched,
164 const sham::DeviceScheduler_ptr &sched,
170 const sham::DeviceScheduler_ptr &sched,
176 const sham::DeviceScheduler_ptr &sched,
184 const sham::DeviceScheduler_ptr &sched,
193 auto value_filter = [&]() {
206 const T *__restrict keys,
207 const T *__restrict bin_edges,
208 u32 *__restrict key_filter) {
210 if (keys[i] < bin_edges[0] || keys[i] >= bin_edges[nbins]) {
220 return valid_key_idxs;
234 if (valid_key_count > 0) {
241 const T *__restrict keys,
242 const T *__restrict values,
243 const u32 *__restrict valid_keys_idxs,
244 T *__restrict valid_keys,
245 T *__restrict valid_values) {
246 u32 src_key = valid_keys_idxs[i];
247 valid_keys[i] = keys[src_key];
248 valid_values[i] = values[src_key];
254 = device_histogram<u32>(sched, bin_edges, nbins, valid_keys, valid_key_count);
257 bin_counts.set_val_at_idx(bin_counts.
get_size() - 1, 0);
266 if (valid_key_count > 0) {
270 if (pow2_len_key > valid_key_count) {
271 valid_keys.
resize(pow2_len_key);
272 valid_values.
resize(pow2_len_key);
278 pow2_len_key - valid_key_count,
279 [offset_start = valid_key_count](
280 u32 i, T *__restrict valid_keys, T *__restrict valid_values) {
281 u32 key_id = offset_start + i;
293 return {std::move(valid_values), std::move(offsets_bins)};
297 const sham::DeviceScheduler_ptr &sched,
304 const sham::DeviceScheduler_ptr &sched,
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
Shamrock assertion utility.
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
A buffer allocated in USM (Unified Shared Memory)
DeviceQueue & get_queue() const
Gets the DeviceQueue associated with the held allocation.
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
void fill(T value, std::array< size_t, 2 > idx_range)
Fill a subpart of the buffer with a given value.
T get_val_at_idx(size_t idx) const
Get the value at a given index in the buffer.
size_t get_size() const
Gets the number of elements in the buffer.
void expand(u32 add_sz)
Expand the buffer by add_sz elements.
main include file for the shamalgs algorithms
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
void sort_by_key(sycl::queue &q, sycl::buffer< Tkey > &buf_key, sycl::buffer< Tval > &buf_values, u32 len)
Sort the buffer according to the key order.
namespace containing the numeric algorithms of shamalgs
BinnedCompute< T > binned_init_compute(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &bin_edges, u64 nbins, const sham::DeviceBuffer< T > &values, const sham::DeviceBuffer< T > &keys, u32 len)
Prepare binned data for per-bin computation.
sham::DeviceBuffer< Tret > device_histogram(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &bin_edges, u64 nbins, const sham::DeviceBuffer< T > &values, u32 len)
Compute the histogram of values between bin_edges.
sycl::buffer< T > scan_exclusive(sycl::queue &q, sycl::buffer< T > &buf1, u32 len)
Computes the exclusive sum of elements in a SYCL buffer.
std::tuple< std::optional< sycl::buffer< u32 > >, u32 > stream_compact(sycl::queue &q, sycl::buffer< u32 > &buf_flags, u32 len)
Stream compaction algorithm.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
constexpr T roundup_pow2(T v) noexcept
round up to the next power of two Source : https://graphics.stanford.edu/~seander/bithacks....
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
A class that references multiple buffers or similar objects.
Structure holding the result of binning values for further computation.