50 template<
class Tkey,
class Tval, u32 group_size, u32 digit_len>
53 template<
class Tkey,
class Tval, u32 group_size, u32 digit_len>
54 void sort_by_key_radix_onesweep(
55 sycl::queue &q, sycl::buffer<Tkey> &buf_key, sycl::buffer<Tval> &buf_values,
u32 len) {
57 sycl::buffer<Tkey> tmp_buf_key(len);
58 sycl::buffer<Tval> tmp_buf_values(len);
60 auto get_in_keys = [&](
u32 step) -> sycl::buffer<Tkey> & {
68 auto get_out_keys = [&](
u32 step) -> sycl::buffer<Tkey> & {
76 auto get_in_vals = [&](
u32 step) -> sycl::buffer<Tval> & {
80 return tmp_buf_values;
84 auto get_out_vals = [&](
u32 step) -> sycl::buffer<Tval> & {
86 return tmp_buf_values;
95 u32 corrected_len = group_cnt * group_size;
99 using Binner = DigitBinner<Tkey, digit_len>;
101 sycl::buffer<u32> digit_histogram
102 = Binner::template make_digit_histogram<group_size>(q, buf_key, len);
109 sycl::host_accessor acc{digit_histogram, sycl::read_write};
111 auto ptr = &(acc[0]);
113 for (
u32 digit_place = 0; digit_place < Binner::digit_bit_places; digit_place++) {
114 u32 offset_ptr = Binner::digit_count * digit_place;
116 ptr + offset_ptr, ptr + offset_ptr + Binner::digit_count, ptr + offset_ptr, 0);
123 using namespace shamalgs::numeric::details;
125 using DecoupledLookBack
129 for (Tkey cur_digit_place = 0; cur_digit_place < shambase::bitsizeof<Tkey>;
130 cur_digit_place += digit_len) {
132 DecoupledLookBack dlookbackscan(q, group_cnt, Binner::digit_count);
136 q.submit([&, len, cur_digit_place, step](sycl::handler &cgh) {
137 sycl::accessor keys{get_in_keys(step), cgh, sycl::read_only};
138 sycl::accessor vals{get_in_vals(step), cgh, sycl::read_only};
140 sycl::accessor new_keys{get_out_keys(step), cgh, sycl::write_only, sycl::no_init};
141 sycl::accessor new_vals{get_out_vals(step), cgh, sycl::write_only, sycl::no_init};
143 sycl::accessor value_write_offsets{digit_histogram, cgh, sycl::read_only};
145 sycl::local_accessor<u32, 1> local_digit_counts{Binner::digit_count, cgh};
146 sycl::local_accessor<u32, 1> scanned_digit_counts{Binner::digit_count, cgh};
149 auto dyn_id = id_gen.get_access(cgh);
151 auto scanop = dlookbackscan.get_access(cgh);
153 using at_ref_loc_count = sycl::atomic_ref<
155 sycl::memory_order_relaxed,
156 sycl::memory_scope_work_group,
157 sycl::access::address_space::local_space>;
159 u32 histogram_ptr_offset = step * Binner::digit_count;
161 cgh.parallel_for<SortByKeyRadixOnesweep<Tkey, Tval, group_size, digit_len>>(
162 sycl::nd_range<1>{corrected_len, group_size}, [=](sycl::nd_item<1> id) {
165 u32 local_id =
id.get_local_id(0);
166 u32 group_tile_id = group_id.dyn_group_id;
167 u32 global_id = group_id.dyn_global_id;
173 for (
u32 digit_ptr = 0; digit_ptr < Binner::digit_count; digit_ptr++) {
174 local_digit_counts[digit_ptr] = 0;
177 id.barrier(sycl::access::fence_space::local_space);
179 bool is_valid_key = (global_id < len);
181 Tkey cur_key = (is_valid_key) ? keys[global_id] : 0;
183 Tkey digit_value = Binner::get_digit_value(cur_key, step);
190 u32 curr_loc_offset = at_ref_loc_count(local_digit_counts[digit_value])
191 .fetch_add((is_valid_key) ? 1U : 0);
199 id.barrier(sycl::access::fence_space::local_space);
206 for (
u32 digit_ptr = 0; digit_ptr < Binner::digit_count; digit_ptr++) {
208 scanop.decoupled_lookback_scan(
213 return local_digit_counts[digit_ptr];
216 scanned_digit_counts[digit_ptr] = accum;
231 if (global_id < len) {
235 u32 value_write_offset_global
236 = value_write_offsets[(digit_value) + histogram_ptr_offset];
238 u32 write_offset = curr_loc_offset + scanned_digit_counts[digit_value]
239 + value_write_offset_global;
251 new_keys[write_offset]
254 new_vals[write_offset] = vals[global_id];
std::uint32_t u32
32-bit unsigned integer
SYCL utility to dynamically generate group ids.
Object returned by DynamicIdGenerator containing information about the id assigned to the worker.
namespace to store algorithms implemented by shamalgs
constexpr u32 group_count(u32 len, u32 group_size)
Calculates the number of groups based on the length and group size.
main include file for memory algorithms