29namespace shamalgs::reduction::details {
// One device-side pass of a multi-pass tree reduction.
// Each work item reads a single element of global_mem at index
// gid * cur_slice_sz. When that index is >= max_id, it uses the value from
// identity_getter() instead. The work group then combines those values with
// group_combine, and the group's result is written to global_mem at index
// group_tile_id * (cur_slice_sz * work_group_size).
// After the kernel is submitted:
//  - cur_slice_sz is multiplied by work_group_size;
//  - remaining_val is set to the number of work groups launched.
// Returns the event of the submitted command group.
// NOTE(review): the parameter list (original lines 33-39: q, depends_list,
// global_mem, max_id, cur_slice_sz, remaining_val, work_group_size, ...) and
// the definition of exec_range (lines 47-49) are not visible in this listing.
// cur_slice_sz / remaining_val are presumably taken by reference so that the
// caller sees the updates. TODO confirm.
31 template<
class T,
class GroupCombiner,
class IdentityGetter>
32 inline sycl::event reduc_step(
40 GroupCombiner &&group_combine,
41 IdentityGetter &&identity_getter) {
45 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
// Stride between the elements read by consecutive work items.
46 u32 slice_read_size = cur_slice_sz;
// Stride between the per-group results written by consecutive groups.
47 u32 slice_write_size = cur_slice_sz * work_group_size;
50 cgh.parallel_for(exec_range, [=](sycl::nd_item<1> item) {
51 u64 lid = item.get_local_id(0);
52 u64 group_tile_id = item.get_group_linear_id();
53 u64 gid = group_tile_id * work_group_size + lid;
55 u64 iread = gid * slice_read_size;
56 u64 iwrite = group_tile_id * slice_write_size;
// Work items past max_id contribute identity_getter()'s value. This is
// assumed to be the neutral element of the combine operation (not verifiable here).
58 T val_read = (iread < max_id) ? global_mem[iread] : identity_getter();
// Work-group-wide combination of val_read. group_combine presumably wraps
// a SYCL group algorithm such as reduce_over_group. Confirm at call sites.
60 T local_red = group_combine(item.get_group(), val_read);
// NOTE(review): original lines 61-63 are not shown. Presumably only one
// work item per group (e.g. lid == 0) performs this write. Confirm.
64 global_mem[iwrite] = local_red;
// Each pass makes the surviving partial results work_group_size times sparser.
69 cur_slice_sz *= work_group_size;
// One partial result remains per launched work group.
70 remaining_val = exec_range.get_group_range().size();
// Host-side driver that reduces the elements in [start_id, end_id) and
// returns the result on the host.
// Steps visible in this listing:
//  1. Throw if the range is empty or inverted (start_id >= end_id).
//  2. Call reduc_step repeatedly while len / cur_slice_sz exceeds
//     work_group_size * 8.
//  3. Run one more kernel that gathers the remaining_val partial results
//     (stride cur_slice_sz) from compute_buf into recov_buf.
//  4. Copy recov_buf to a host std::vector and fold it with binary_op.
// NOTE(review): several parameters and statements are not visible here
// (original lines 78-81, 83, 94-99, 106-112, 118-121, 136-138, 145, 147).
// Names such as q, buf, work_group_size, buf_int, recov_buf, depends_list,
// old_list, exec_range and ret are declared there. TODO confirm against the
// full file.
75 template<
class T,
class GroupCombiner,
class BinaryOp,
class IdentityGetter>
76 inline T reduc_internal(
77 const sham::DeviceScheduler_ptr &sched,
82 GroupCombiner &&group_combine,
84 IdentityGetter &&identity_getter) {
// Reject empty or inverted ranges. This reduction has no result for them.
88 if (start_id >= end_id) {
90 "Empty (or invalid) range not supported for reduction operation");
// NOTE(review): len is stored as u32. If start_id/end_id are wider types,
// a range longer than 2^32 - 1 elements would be truncated. Confirm the types.
93 u32 len = end_id - start_id;
// Writable device pointer to the working buffer. buf_int presumably holds
// a copy of the input range made on lines 94-99 (not shown). Confirm.
100 T *compute_buf = buf_int.get_write_access(depends_list);
102 u32 cur_slice_sz = 1;
103 u32 remaining_val = len;
// Keep doing device passes until at most about work_group_size * 8 partial
// values would remain.
104 while (len / cur_slice_sz > work_group_size * 8) {
// NOTE(review): std::forward is applied to the same objects on every loop
// iteration. This is only safe because reduc_step takes the callables by
// forwarding reference and copies them into its [=] kernel lambda instead of
// moving from them. Keep that invariant if reduc_step changes.
105 auto e = reduc_step<T>(
113 std::forward<GroupCombiner>(group_combine),
114 std::forward<IdentityGetter>(identity_getter));
117 std::swap(depends_list, old_list);
// Device pointer for the compacted partial results.
122 T *result = recov_buf.get_write_access(depends_list);
// remaining_val is captured by value explicitly. Everything else is captured by reference.
125 auto e = q.
submit(depends_list, [&, remaining_val](sycl::handler &cgh) {
126 u32 slice_read_size = cur_slice_sz;
128 cgh.parallel_for(exec_range, [=](sycl::nd_item<1> item) {
129 u64 lid = item.get_local_id(0);
130 u64 group_tile_id = item.get_group_linear_id();
131 u64 gid = group_tile_id * work_group_size + lid;
133 u64 iread = gid * slice_read_size;
// Work items past the last partial result do not gather anything. The body
// of this branch (lines 136-138) is not shown and is presumably a return.
135 if (gid >= remaining_val) {
// Gather partial result number gid from compute_buf (stride cur_slice_sz)
// into a contiguous slot of recov_buf.
139 result[gid] = compute_buf[iread];
// Register the kernel's event with both buffers. Presumably this makes later
// accesses wait on it.
143 buf_int.complete_event_state(e);
144 recov_buf.complete_event_state(e);
// Host copy of the compacted partial results.
146 auto acc = recov_buf.copy_to_stdvec();
// Final host-side fold. The loop starts at i = 1, so ret is presumably
// initialized from acc[0] on a line not shown (147). Confirm.
148 for (
u64 i = 1; i < remaining_val; i++) {
149 ret = binary_op(ret, acc[i]);
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void copy_range(size_t begin, size_t end, sham::DeviceBuffer< T, dest_target > &dest) const
Copy a range of elements from the buffer to another buffer.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
Class to manage a list of SYCL events.
void add_event(sycl::event e)
Add an event to the list of events.
This header file contains utility functions related to exception handling in the code.
Define the fmt formatters for sycl::vec.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
sycl::nd_range< 1 > make_range(u32 length, const u32 group_size=32)
Generate a SYCL nd_range from a length and a group size.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
main include file for memory algorithms