Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
segmented_sort_in_place.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
18#include "shambase/assert.hpp"
21
22namespace shamalgs::primitives::details {
23
24 template<class T, class Comp>
25 inline void segmented_sort_in_place_local_insertion_sort(
26 sham::DeviceBuffer<T> &buf, const sham::DeviceBuffer<u32> &offsets, Comp &&comp) {
27
28 auto &q = buf.get_dev_scheduler().get_queue();
29
30 size_t interact_count = buf.get_size();
31 size_t offsets_count = offsets.get_size();
32 size_t N = offsets_count - 1;
33
35 q,
36 sham::MultiRef{offsets},
37 sham::MultiRef{buf},
38 N,
39 [interact_count,
40 comp](u32 gid, const u32 *__restrict__ offsets, T *__restrict__ in_out_sorted) {
41 u32 start_index = offsets[gid];
42 u32 end_index = offsets[gid + 1];
43
44 // can be equal if there is no interaction for this sender
45 SHAM_ASSERT(start_index <= end_index);
46
47 // skip empty ranges to avoid unnecessary work
48 if (start_index == end_index) {
49 return;
50 }
51
52 // if there is no interactions at the end of the offset list
53 // offsets[gid] can be equal to interact_count
54 // but we check that start_index != end_index, so here the correct assertions
55 // is indeed start_index < interact_count
56 SHAM_ASSERT(start_index < interact_count);
57 SHAM_ASSERT(end_index <= interact_count); // see the for loop for this one
58
59 shambase::ptr_insert_sort(in_out_sorted, start_index, end_index, comp);
60 });
61 }
62
63 template<class T, class Comp>
64 inline void segmented_sort_in_place_multi_std_sort(
65 sham::DeviceBuffer<T> &buf, const sham::DeviceBuffer<u32> &offsets, Comp &&comp) {
66
67 auto &q = buf.get_dev_scheduler().get_queue();
68
69 size_t interact_count = buf.get_size();
70 size_t offsets_count = offsets.get_size();
71 size_t N = offsets_count - 1;
72
73 std::vector<T> buf_stdvec = buf.copy_to_stdvec();
74 std::vector<u32> offsets_stdvec = offsets.copy_to_stdvec();
75
76#pragma omp parallel for
77 for (u32 i = 0; i < N; ++i) {
78 u32 start_index = offsets_stdvec[i];
79 u32 end_index = offsets_stdvec[i + 1];
80
81 // can be equal if there is no interaction for this sender
82 SHAM_ASSERT(start_index <= end_index);
83
84 // skip empty ranges to avoid unnecessary work
85 if (start_index == end_index) {
86 continue;
87 }
88
89 // if there is no interactions at the end of the offset list
90 // offsets[gid] can be equal to interact_count
91 // but we check that start_index != end_index, so here the correct assertions
92 // is indeed start_index < interact_count
93 SHAM_ASSERT(start_index < interact_count);
94 SHAM_ASSERT(end_index <= interact_count); // see the for loop for this one
95
96 std::sort(buf_stdvec.begin() + start_index, buf_stdvec.begin() + end_index, comp);
97 }
98
99 buf.copy_from_stdvec(buf_stdvec);
100 }
101
102} // namespace shamalgs::primitives::details
103
104namespace shamalgs::primitives {
105
107 namespace impl {
108
109 enum class SEGMENTED_SORT_IN_PLACE_IMPL : u32 {
110 LOCAL_INSERTION_SORT,
111 MULTI_STD_SORT,
112 };
113
114 SEGMENTED_SORT_IN_PLACE_IMPL get_default_segmented_sort_in_place_impl() {
115 return SEGMENTED_SORT_IN_PLACE_IMPL::MULTI_STD_SORT;
116 }
117
118 SEGMENTED_SORT_IN_PLACE_IMPL segmented_sort_in_place_impl
119 = get_default_segmented_sort_in_place_impl();
120
/**
 * @brief Map an implementation name to its SEGMENTED_SORT_IN_PLACE_IMPL value.
 *
 * "local_insertion_sort" -> LOCAL_INSERTION_SORT, "multi_std_sort" -> MULTI_STD_SORT.
 * Any other name reaches the "invalid implementation" error path below, whose message
 * template names the rejected implementation and the possible ones.
 *
 * NOTE(review): source lines 128 and 131 are missing from this listing — presumably the call
 * that consumes the format string/arguments on lines 129-130 and its last argument. Confirm the
 * exact error-raising statement against the repository before relying on it.
 */
121 inline SEGMENTED_SORT_IN_PLACE_IMPL segmented_sort_in_place_impl_from_params(
122 const std::string &impl) {
123 if (impl == "local_insertion_sort") {
124 return SEGMENTED_SORT_IN_PLACE_IMPL::LOCAL_INSERTION_SORT;
125 } else if (impl == "multi_std_sort") {
126 return SEGMENTED_SORT_IN_PLACE_IMPL::MULTI_STD_SORT;
127 }
// unknown name: error message built from the rejected name and the possible implementations
129 "invalid implementation : {}, possible implementations : {}",
130 impl,
132 }
133
/**
 * @brief Convert a SEGMENTED_SORT_IN_PLACE_IMPL value back to its impl_param form.
 *
 * Returns {"local_insertion_sort", ""} or {"multi_std_sort", ""}; both carry an empty
 * parameter string. Any other value reaches the error path, whose message includes the raw
 * numeric value of `impl`.
 *
 * NOTE(review): source line 141 is missing from this listing — presumably the start of the
 * error-raising call whose argument is on line 142. Confirm against the repository.
 */
134 inline shamalgs::impl_param segmented_sort_in_place_impl_to_params(
135 const SEGMENTED_SORT_IN_PLACE_IMPL &impl) {
136 if (impl == SEGMENTED_SORT_IN_PLACE_IMPL::LOCAL_INSERTION_SORT) {
137 return {"local_insertion_sort", ""};
138 } else if (impl == SEGMENTED_SORT_IN_PLACE_IMPL::MULTI_STD_SORT) {
139 return {"multi_std_sort", ""};
140 }
// not a known enumerator: report its numeric value
142 shambase::format("unknown segmented sort in place implementation : {}", u32(impl)));
143 }
144
146 std::vector<shamalgs::impl_param> get_default_impl_list_segmented_sort_in_place() {
147 return {
148 {"local_insertion_sort", ""},
149 {"multi_std_sort", ""},
150 };
151 }
152
// NOTE(review): the signature (source lines 153-154) is missing from this listing; according to
// the page's symbol index this body presumably belongs to
// `shamalgs::impl_param get_current_impl_segmented_sort_in_place()` — confirm.
// Reports the currently selected back-end (segmented_sort_in_place_impl) in impl_param form.
155 return segmented_sort_in_place_impl_to_params(segmented_sort_in_place_impl);
156 }
157
159 void set_impl_segmented_sort_in_place(const std::string &impl, const std::string &param) {
160 shamlog_info_ln(
161 "tree", "setting segmented sort in place implementation to impl :", impl);
162 segmented_sort_in_place_impl = segmented_sort_in_place_impl_from_params(impl);
163 }
164
165 } // namespace impl
166
/**
 * @brief Validate the inputs, then dispatch to the back-end currently selected in
 *        impl::segmented_sort_in_place_impl.
 *
 * @param buf     Buffer whose segments are sorted in place. An empty buffer returns
 *                immediately, before `offsets` is inspected.
 * @param offsets Segment boundaries. If it is empty while `buf` is not, the exception built by
 *                make_except_with_loc<std::invalid_argument>("offsets buffer is empty") is
 *                thrown. This check also keeps the back-ends' `offsets.get_size() - 1` from
 *                wrapping around.
 * @param comp    Comparator passed on (as an lvalue) to the selected back-end.
 *
 * NOTE(review): source line 188 is missing from this listing — presumably the start of the
 * error-raising call for the `default:` branch whose arguments are on line 189.
 */
167 template<class T, class Comp>
168 void internal_segmented_sort_in_place(
169 sham::DeviceBuffer<T> &buf, const sham::DeviceBuffer<u32> &offsets, Comp &&comp) {
170
// nothing to sort
171 if (buf.get_size() == 0) {
172 return;
173 }
174
// at least one boundary is required to derive the segment count
175 if (offsets.get_size() == 0) {
176 throw shambase::make_except_with_loc<std::invalid_argument>("offsets buffer is empty");
177 }
178
179 switch (impl::segmented_sort_in_place_impl) {
180 case impl::SEGMENTED_SORT_IN_PLACE_IMPL::LOCAL_INSERTION_SORT:
181 details::segmented_sort_in_place_local_insertion_sort(buf, offsets, comp);
182 break;
183
184 case impl::SEGMENTED_SORT_IN_PLACE_IMPL::MULTI_STD_SORT:
185 details::segmented_sort_in_place_multi_std_sort(buf, offsets, comp);
186 break;
187 default:
189 "unimplemented case : {}", u32(impl::segmented_sort_in_place_impl)));
190 }
191 }
192
/**
 * @brief Segmented in-place sort of u32_2 values, ordered lexicographically: by x(), and by
 *        y() when the x() components are equal.
 *
 * NOTE(review): source line 195 (the parameter list, presumably declaring the `buf` and
 * `offsets` used below) is missing from this listing — confirm against the repository.
 */
193 template<>
194 void segmented_sort_in_place<u32_2>(
196
197 internal_segmented_sort_in_place(buf, offsets, [](u32_2 a, u32_2 b) {
198 return (a.x() == b.x()) ? (a.y() < b.y()) : (a.x() < b.x());
199 });
200 }
201
/**
 * @brief Segmented in-place sort of u32 values in ascending order (operator<).
 *
 * NOTE(review): source line 204 (the parameter list, presumably declaring the `buf` and
 * `offsets` used below) is missing from this listing — confirm against the repository.
 */
202 template<>
203 void segmented_sort_in_place<u32>(
205 internal_segmented_sort_in_place(buf, offsets, [](u32 a, u32 b) {
206 return a < b;
207 });
208 }
209
210} // namespace shamalgs::primitives
std::uint32_t u32
32 bit unsigned integer
Shamrock assertion utility.
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
Definition assert.hpp:67
A buffer allocated in USM (Unified Shared Memory)
void copy_from_stdvec(const std::vector< T > &vec)
Copy the content of a std::vector into the buffer.
std::vector< T > copy_to_stdvec() const
Copy the content of the buffer to a std::vector.
size_t get_size() const
Gets the number of elements in the buffer.
DeviceScheduler & get_dev_scheduler() const
Gets the Device scheduler corresponding to the held allocation.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
std::vector< shamalgs::impl_param > get_default_impl_list_segmented_sort_in_place()
Get list of available segmented sort in place implementations.
void set_impl_segmented_sort_in_place(const std::string &impl, const std::string &param="")
Set the implementation for segmented sort in place.
shamalgs::impl_param get_current_impl_segmented_sort_in_place()
Get the current implementation for segmented sort in place.
namespace for primitive algorithm (e.g. sort, scan, reductions, ...)
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
void ptr_insert_sort(T *data, u32 start, u32 end, Comp &&comp)
Simple insertion sort on pointer range.
A class that references multiple buffers or similar objects.