Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
groupReduction_usm_impl.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
20#include "shambase/memory.hpp"
22#include "shamalgs/memory.hpp"
24#include "shambackends/math.hpp"
25#include "shambackends/sycl.hpp"
27#include "shambackends/vec.hpp"
28
29namespace shamalgs::reduction::details {
30
31 template<class T, class GroupCombiner, class IdentityGetter>
32 inline sycl::event reduc_step(
34 T *global_mem,
35 sham::EventList &depends_list,
36 u32 len,
37 u32 &cur_slice_sz,
38 u32 &remaining_val,
39 u32 work_group_size,
40 GroupCombiner &&group_combine,
41 IdentityGetter &&identity_getter) {
42
43 sycl::nd_range<1> exec_range = shambase::make_range(remaining_val, work_group_size);
44
45 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
46 u32 slice_read_size = cur_slice_sz;
47 u32 slice_write_size = cur_slice_sz * work_group_size;
48 u32 max_id = len;
49
50 cgh.parallel_for(exec_range, [=](sycl::nd_item<1> item) {
51 u64 lid = item.get_local_id(0);
52 u64 group_tile_id = item.get_group_linear_id();
53 u64 gid = group_tile_id * work_group_size + lid;
54
55 u64 iread = gid * slice_read_size;
56 u64 iwrite = group_tile_id * slice_write_size;
57
58 T val_read = (iread < max_id) ? global_mem[iread] : identity_getter();
59
60 T local_red = group_combine(item.get_group(), val_read);
61
62 // can be removed if i change the index in the look back ?
63 if (lid == 0) {
64 global_mem[iwrite] = local_red;
65 }
66 });
67 });
68
69 cur_slice_sz *= work_group_size;
70 remaining_val = exec_range.get_group_range().size();
71
72 return e;
73 }
74
75 template<class T, class GroupCombiner, class BinaryOp, class IdentityGetter>
76 inline T reduc_internal(
77 const sham::DeviceScheduler_ptr &sched,
78 const sham::DeviceBuffer<T> &buf1,
79 u32 start_id,
80 u32 end_id,
81 u32 work_group_size,
82 GroupCombiner &&group_combine,
83 BinaryOp &&binary_op,
84 IdentityGetter &&identity_getter) {
85
86 sham::DeviceQueue &q = shambase::get_check_ref(sched).get_queue();
87
88 if (start_id >= end_id) {
90 "Empty (or invalid) range not supported for reduction operation");
91 }
92
93 u32 len = end_id - start_id;
94
95 sham::DeviceBuffer<T> buf_int(len, sched);
96
97 buf1.copy_range(start_id, end_id, buf_int);
98
99 sham::EventList depends_list;
100 T *compute_buf = buf_int.get_write_access(depends_list);
101
102 u32 cur_slice_sz = 1;
103 u32 remaining_val = len;
104 while (len / cur_slice_sz > work_group_size * 8) {
105 auto e = reduc_step<T>(
106 q,
107 compute_buf,
108 depends_list,
109 len,
110 cur_slice_sz,
111 remaining_val,
112 work_group_size,
113 std::forward<GroupCombiner>(group_combine),
114 std::forward<IdentityGetter>(identity_getter));
115
116 sham::EventList old_list;
117 std::swap(depends_list, old_list);
118 depends_list.add_event(e);
119 }
120
121 sham::DeviceBuffer<T> recov_buf(remaining_val, sched);
122 T *result = recov_buf.get_write_access(depends_list);
123
124 sycl::nd_range<1> exec_range = shambase::make_range(remaining_val, work_group_size);
125 auto e = q.submit(depends_list, [&, remaining_val](sycl::handler &cgh) {
126 u32 slice_read_size = cur_slice_sz;
127
128 cgh.parallel_for(exec_range, [=](sycl::nd_item<1> item) {
129 u64 lid = item.get_local_id(0);
130 u64 group_tile_id = item.get_group_linear_id();
131 u64 gid = group_tile_id * work_group_size + lid;
132
133 u64 iread = gid * slice_read_size;
134
135 if (gid >= remaining_val) {
136 return;
137 }
138
139 result[gid] = compute_buf[iread];
140 });
141 });
142
143 buf_int.complete_event_state(e);
144 recov_buf.complete_event_state(e);
145
146 auto acc = recov_buf.copy_to_stdvec();
147 T ret = acc[0]; // init value
148 for (u64 i = 1; i < remaining_val; i++) {
149 ret = binary_op(ret, acc[i]);
150 }
151
152 return ret;
153 }
154} // namespace shamalgs::reduction::details
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void copy_range(size_t begin, size_t end, sham::DeviceBuffer< T, dest_target > &dest) const
Copy a range of elements from the buffer to another buffer.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
void add_event(sycl::event e)
Add an event to the list of events.
Definition EventList.hpp:87
This header file contains utility functions related to exception handling in the code.
Define the fmt formatters for sycl::vec.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
sycl::nd_range< 1 > make_range(u32 length, const u32 group_size=32)
Generate a sycl nd range out of a group size and length.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
main include file for memory algorithms