Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
pyShamalgs.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
17#include "shambase/time.hpp"
25#include "shamalgs/random.hpp"
29#include "shamcomm/logs.hpp"
31#include <pybind11/complex.h>
32
// Registers the shamalgs Python bindings through Register_pymod (per its tooltip: a python
// module init function registered using static initialisation). Every binding is attached
// to the `algs` submodule; many of them pass shamsys::instance::get_compute_scheduler_ptr()
// to the underlying C++ call.
// NOTE(review): this listing comes from a rendered documentation page whose embedded source
// numbering skips lines (e.g. 121 -> 123, 143 -> 145, 281 -> 283), so the calls on those
// elided lines cannot be reviewed here.
33Register_pymod(shamalgslibinit) {
34
 // `algs` submodule that receives every binding below.
35 py::module shamalgs_module = m.def_submodule("algs", "algorithmic library");
36
 // Opaque Python handle for std::mt19937. No constructor is bound here; gen_seed() below
 // returns instances, and the mock_* samplers take them by reference.
37 py::class_<std::mt19937>(shamalgs_module, "rng");
38
 // impl_param: the (impl_name, params) pair describing an algorithm implementation, as
 // returned by the get_current_impl_* bindings. The Python constructor takes no arguments
 // and sets both fields to "".
39 py::class_<shamalgs::impl_param>(shamalgs_module, "impl_param")
40 .def(py::init([]() {
41 return shamalgs::impl_param{"", ""};
42 }))
 // NOTE(review): both fields are initialised from string literals above and formatted with
 // "{}", so they are presumably std::string. pybind11 hands strings to Python as copies, in
 // which case reference_internal has no effect here — confirm it is needed.
43 .def_readwrite(
44 "impl_name",
45 &shamalgs::impl_param::impl_name,
46 py::return_value_policy::reference_internal)
47 .def_readwrite(
48 "params", &shamalgs::impl_param::params, py::return_value_policy::reference_internal)
 // __str__ and __repr__ use the same format string and fields, so they return identical text.
49 .def(
50 "__str__",
51 [](const shamalgs::impl_param &impl_param) {
52 return shambase::format(
53 "impl_param(impl_name=\"{}\", params=\"{}\")",
54 impl_param.impl_name,
55 impl_param.params);
56 })
57 .def("__repr__", [](const shamalgs::impl_param &impl_param) {
58 return shambase::format(
59 "impl_param(impl_name=\"{}\", params=\"{}\")",
60 impl_param.impl_name,
61 impl_param.params);
62 });
63
 // gen_seed(seed) -> rng: a std::mt19937 constructed from `seed`.
 // NOTE(review): mt19937 reduces its seed value modulo 2^32, so u64 seeds that differ only in
 // their upper 32 bits produce identical streams — confirm this is acceptable.
64 shamalgs_module.def("gen_seed", [](u64 seed) {
65 return std::mt19937(seed);
66 });
67
 // Samplers forwarding to shamalgs::random::{mock_gaussian, mock_gaussian_multidim,
 // mock_unit_vector}. `eng` is taken by reference, so each call advances the state of the
 // Python-side rng object that was passed in.
68 shamalgs_module.def("mock_gaussian", [](std::mt19937 &eng) {
69 return shamalgs::random::mock_gaussian<f64>(eng);
70 });
71 shamalgs_module.def("mock_gaussian_f64_2", [](std::mt19937 &eng) {
72 return shamalgs::random::mock_gaussian_multidim<f64_2>(eng);
73 });
74 shamalgs_module.def("mock_gaussian_f64_3", [](std::mt19937 &eng) {
75 return shamalgs::random::mock_gaussian_multidim<f64_3>(eng);
76 });
77 shamalgs_module.def("mock_unit_vector_f64_3", [](std::mt19937 &eng) {
78 return shamalgs::random::mock_unit_vector<f64_3>(eng);
79 });
80
 // mock_buffer_<T>(seed, len, min_bound, max_bound): forward to
 // shamalgs::random::mock_buffer_usm<T> on the current compute scheduler and return its
 // result (presumably a len-element device buffer with values between the bounds — confirm
 // against shamalgs/random.hpp).
81 shamalgs_module.def("mock_buffer_f64", [](u64 seed, u32 len, f64 min_bound, f64 max_bound) {
82 return shamalgs::random::mock_buffer_usm<f64>(
83 shamsys::instance::get_compute_scheduler_ptr(), seed, len, min_bound, max_bound);
84 });
85 shamalgs_module.def("mock_buffer_u8", [](u64 seed, u32 len, u8 min_bound, u8 max_bound) {
86 return shamalgs::random::mock_buffer_usm<u8>(
87 shamsys::instance::get_compute_scheduler_ptr(), seed, len, min_bound, max_bound);
88 });
89 shamalgs_module.def("mock_buffer_u32", [](u64 seed, u32 len, u32 min_bound, u32 max_bound) {
90 return shamalgs::random::mock_buffer_usm<u32>(
91 shamsys::instance::get_compute_scheduler_ptr(), seed, len, min_bound, max_bound);
92 });
93 shamalgs_module.def(
94 "mock_buffer_f64_2", [](u64 seed, u32 len, f64_2 min_bound, f64_2 max_bound) {
95 return shamalgs::random::mock_buffer_usm<f64_2>(
96 shamsys::instance::get_compute_scheduler_ptr(), seed, len, min_bound, max_bound);
97 });
98 shamalgs_module.def(
99 "mock_buffer_f64_3", [](u64 seed, u32 len, f64_3 min_bound, f64_3 max_bound) {
100 return shamalgs::random::mock_buffer_usm<f64_3>(
101 shamsys::instance::get_compute_scheduler_ptr(), seed, len, min_bound, max_bound);
102 });
103
104 { // is_all_true
105
 // is_all_true(buf, len): forwards to shamalgs::primitives::is_all_true.
106 shamalgs_module.def("is_all_true", [](sham::DeviceBuffer<u8> &buf, u32 len) {
107 return shamalgs::primitives::is_all_true(buf, len);
108 });
109
 // Times one is_all_true call and returns the elapsed seconds (Timer::elasped_sec).
 // buf is synchronized before the timer starts, so earlier pending work is not timed, and
 // again after the call, so work triggered by the call is.
110 shamalgs_module.def("benchmark_is_all_true", [](sham::DeviceBuffer<u8> &buf, u32 len) {
111 buf.synchronize();
112 shambase::Timer timer;
113 timer.start();
 // NOTE(review): `result` is never read (unused-variable warning); it only exists to
 // hold the return value of the timed call.
114 bool result = shamalgs::primitives::is_all_true(buf, len);
115 buf.synchronize();
116 timer.end();
117 return timer.elasped_sec();
118 });
119
 // NOTE(review): pybind11 does not see C++ default arguments declared on a lambda, so
 // `param` is still a required argument from Python. Adding
 // py::arg("impl"), py::arg("param") = "" to this def() would make it optional. The same
 // applies to the other set_impl_* bindings below. The body (source line 122) is elided here.
120 shamalgs_module.def(
121 "set_impl_is_all_true", [](const std::string &impl, const std::string &param = "") {
123 });
124
 // Current / default is_all_true implementations; bodies (source lines 126, 130) elided here.
125 shamalgs_module.def("get_current_impl_is_all_true", []() {
127 });
128
129 shamalgs_module.def("get_default_impl_list_is_all_true", []() {
131 });
132 }
133
134 { // reductions
 // sum(buf, start_id, end_id): sum of buf over the start_id..end_id range on the current
 // compute scheduler (the head of the call, source line 136, is elided here).
135 shamalgs_module.def("sum", [](sham::DeviceBuffer<f64> &buf, u32 start_id, u32 end_id) {
137 shamsys::instance::get_compute_scheduler_ptr(), buf, start_id, end_id);
138 });
139
 // Times one sum over 0..len of an f64 buffer and returns the elapsed seconds.
 // NOTE(review): unlike the other benchmarks there is no synchronize() after the call;
 // presumably the sum blocks because it yields a host value — confirm (source line 144,
 // the head of the call, is elided here).
140 shamalgs_module.def("benchmark_reduction_sum", [](sham::DeviceBuffer<f64> &buf, u32 len) {
141 buf.synchronize();
142 shambase::Timer timer;
143 timer.start();
145 shamsys::instance::get_compute_scheduler_ptr(), buf, 0, len);
146 timer.end();
147 return timer.elasped_sec();
148 });
149
 // f32 overload registered under the same Python name. pybind11 tries overloads in
 // registration order, so the f64 version above is attempted first.
150 shamalgs_module.def("benchmark_reduction_sum", [](sham::DeviceBuffer<f32> &buf, u32 len) {
151 buf.synchronize();
152 shambase::Timer timer;
153 timer.start();
155 shamsys::instance::get_compute_scheduler_ptr(), buf, 0, len);
156 timer.end();
157 return timer.elasped_sec();
158 });
159
 // NOTE(review): `param = ""` is a C++ default only; pybind11 still requires it from Python
 // unless declared with py::arg("param") = "". Body (source line 162) elided here.
160 shamalgs_module.def(
161 "set_impl_reduction", [](const std::string &impl, const std::string &param = "") {
163 });
164
 // Current / default reduction implementations; bodies (source lines 166, 170) elided here.
165 shamalgs_module.def("get_current_impl_reduction", []() {
167 });
168
169 shamalgs_module.def("get_default_impl_list_reduction", []() {
171 });
172 }
173
174 { // scan_exclusive_sum_in_place
175
 // scan_exclusive_sum_in_place(buf, len): the forwarding call (source line 178) is elided here.
176 shamalgs_module.def(
177 "scan_exclusive_sum_in_place", [](sham::DeviceBuffer<u32> &buf, u32 len) {
179 });
180
 // Times one exclusive scan between two buf.synchronize() calls; returns elapsed seconds.
 // NOTE(review): unlike benchmark_segmented_sort_in_place, no copy of `buf` is made here.
 // If the elided call (source line 186) scans `buf` itself, running the benchmark overwrites
 // the caller's buffer — confirm this is intended.
181 shamalgs_module.def(
182 "benchmark_scan_exclusive_sum_in_place", [](sham::DeviceBuffer<u32> &buf, u32 len) {
183 buf.synchronize();
184 shambase::Timer timer;
185 timer.start();
187 buf.synchronize();
188 timer.end();
189 return timer.elasped_sec();
190 });
191
 // NOTE(review): `param = ""` is a C++ default only; pybind11 still requires it from Python
 // unless declared with py::arg("param") = "". Body (source line 195) elided here.
192 shamalgs_module.def(
193 "set_impl_scan_exclusive_sum_in_place",
194 [](const std::string &impl, const std::string &param = "") {
196 });
197
 // Current / default scan implementations; bodies (source lines 199, 203) elided here.
198 shamalgs_module.def("get_current_impl_scan_exclusive_sum_in_place", []() {
200 });
201
202 shamalgs_module.def("get_default_impl_list_scan_exclusive_sum_in_place", []() {
204 });
205 }
206
207 { // segmented_sort_in_place
 // segmented_sort_in_place(buf, offsets): forwards to
 // shamalgs::primitives::segmented_sort_in_place, modifying `buf`.
208 shamalgs_module.def(
209 "segmented_sort_in_place",
210 [](sham::DeviceBuffer<u32> &buf, const sham::DeviceBuffer<u32> &offsets) {
211 shamalgs::primitives::segmented_sort_in_place(buf, offsets);
212 });
213
 // Sorts copies of both buffers, so the caller's data is left untouched. The copies are
 // synchronized before the timer starts and after the sort; returns elapsed seconds.
214 shamalgs_module.def(
215 "benchmark_segmented_sort_in_place",
216 [](sham::DeviceBuffer<u32> &buf, const sham::DeviceBuffer<u32> &offsets) {
217 auto buf_copy = buf.copy();
218 auto offsets_copy = offsets.copy();
219
220 buf_copy.synchronize();
221 offsets_copy.synchronize();
222
223 shambase::Timer timer;
224 timer.start();
225
226 shamalgs::primitives::segmented_sort_in_place(buf_copy, offsets_copy);
227 buf_copy.synchronize();
228 offsets_copy.synchronize();
229
230 timer.end();
231 return timer.elasped_sec();
232 });
233
 // NOTE(review): `param = ""` is a C++ default only; pybind11 still requires it from Python
 // unless declared with py::arg("param") = "". Body (source line 237) elided here.
234 shamalgs_module.def(
235 "set_impl_segmented_sort_in_place",
236 [](const std::string &impl, const std::string &param = "") {
238 });
239
 // Current / default segmented-sort implementations; bodies (source lines 241, 245) elided.
240 shamalgs_module.def("get_current_impl_segmented_sort_in_place", []() {
242 });
243
244 shamalgs_module.def("get_default_impl_list_segmented_sort_in_place", []() {
246 });
247 }
248
 // Python wrapper for shamalgs::primitives::ImplControl. Every method except get_alg_name
 // passes the current compute scheduler to the matching C++ member function.
249 py::class_<shamalgs::primitives::ImplControl>(shamalgs_module, "ImplControl")
250 .def(
251 "get_alg_name",
252 [](shamalgs::primitives::ImplControl &impl_control) {
253 return impl_control.get_alg_name();
254 })
255 .def(
256 "was_configured",
257 [](shamalgs::primitives::ImplControl &impl_control) {
258 return impl_control.was_configured(shamsys::instance::get_compute_scheduler_ptr());
259 })
260 .def(
261 "get_config",
262 [](shamalgs::primitives::ImplControl &impl_control) {
263 return impl_control.get_config(shamsys::instance::get_compute_scheduler_ptr());
264 })
265 .def(
266 "set_config",
267 [](shamalgs::primitives::ImplControl &impl_control, const std::string &config) {
268 impl_control.set_config(shamsys::instance::get_compute_scheduler_ptr(), config);
269 })
270 .def(
271 "get_default_config",
272 [](shamalgs::primitives::ImplControl &impl_control) {
273 return impl_control.get_default_config(
274 shamsys::instance::get_compute_scheduler_ptr());
275 })
276 .def("get_avail_configs", [](shamalgs::primitives::ImplControl &impl_control) {
277 return impl_control.get_avail_configs(shamsys::instance::get_compute_scheduler_ptr());
278 });
279
 // Returns shamalgs::primitives::impl::compute_histogram_impl_control with
 // return_value_policy::reference, so Python gets a handle it does not own. The lambda's
 // signature (source line 282) is elided in this listing.
280 shamalgs_module.def(
281 "compute_histogram_impl",
283 return shamalgs::primitives::impl::compute_histogram_impl_control;
284 },
285 py::return_value_policy::reference);
286
 // compute_histogram_basic_{f64,f32}(bin_edge_inf, bin_edge_sup, positions): forward the
 // three buffers to shamalgs::primitives::compute_histogram_basic<T> on the current
 // compute scheduler and return its result.
287 shamalgs_module.def(
288 "compute_histogram_basic_f64",
289 [](sham::DeviceBuffer<f64> &bin_edge_inf,
290 sham::DeviceBuffer<f64> &bin_edge_sup,
291 sham::DeviceBuffer<f64> &positions) {
292 return shamalgs::primitives::compute_histogram_basic<f64>(
293 shamsys::instance::get_compute_scheduler_ptr(),
294 bin_edge_inf,
295 bin_edge_sup,
296 positions);
297 });
298 shamalgs_module.def(
299 "compute_histogram_basic_f32",
300 [](sham::DeviceBuffer<f32> &bin_edge_inf,
301 sham::DeviceBuffer<f32> &bin_edge_sup,
302 sham::DeviceBuffer<f32> &positions) {
303 return shamalgs::primitives::compute_histogram_basic<f32>(
304 shamsys::instance::get_compute_scheduler_ptr(),
305 bin_edge_inf,
306 bin_edge_sup,
307 positions);
308 });
309
 // Benchmarks for compute_histogram_basic. The inputs are synchronized first, and `run` is
 // executed once untimed as a warm-up. The binding then returns shambase::timeitfor(run),
 // which per its tooltip is the average execution time over repeated calls until the
 // default max_duration (1) is reached. Each run synchronizes its result, so device work
 // is included in the timing.
310 shamalgs_module.def(
311 "benchmark_compute_histogram_basic_f64",
312 [](sham::DeviceBuffer<f64> &bin_edge_inf,
313 sham::DeviceBuffer<f64> &bin_edge_sup,
314 sham::DeviceBuffer<f64> &positions) {
315 bin_edge_inf.synchronize();
316 bin_edge_sup.synchronize();
317 positions.synchronize();
318
319 auto run = [&]() {
320 auto result = shamalgs::primitives::compute_histogram_basic<f64>(
321 shamsys::instance::get_compute_scheduler_ptr(),
322 bin_edge_inf,
323 bin_edge_sup,
324 positions);
325 result.synchronize();
326 };
327
328 run();
329
330 return shambase::timeitfor(run);
331 });
332 shamalgs_module.def(
333 "benchmark_compute_histogram_basic_f32",
334 [](sham::DeviceBuffer<f32> &bin_edge_inf,
335 sham::DeviceBuffer<f32> &bin_edge_sup,
336 sham::DeviceBuffer<f32> &positions) {
337 bin_edge_inf.synchronize();
338 bin_edge_sup.synchronize();
339 positions.synchronize();
340
341 auto run = [&]() {
342 auto result = shamalgs::primitives::compute_histogram_basic<f32>(
343 shamsys::instance::get_compute_scheduler_ptr(),
344 bin_edge_inf,
345 bin_edge_sup,
346 positions);
347 result.synchronize();
348 };
349
350 run();
351
352 return shambase::timeitfor(run);
353 });
354}
Header file describing a Node Instance.
double f64
Alias for double.
float f32
Alias for float.
std::uint8_t u8
8 bit unsigned integer
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void synchronize() const
Wait for all the events associated with the buffer to be completed.
DeviceBuffer< T, target > copy() const
Copy the current buffer.
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:96
void end()
Stops the timer and stores the elapsed time in nanoseconds.
Definition time.hpp:111
f64 elasped_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
Definition time.hpp:123
void start()
Starts the timer.
Definition time.hpp:106
Boolean reduction algorithm for checking if all elements are non-zero.
std::vector< shamalgs::impl_param > get_default_impl_list_segmented_sort_in_place()
Get list of available segmented sort in place implementations.
void set_impl_reduction(const std::string &impl, const std::string &param="")
Set the implementation for reduction.
Definition reduction.cpp:99
std::vector< shamalgs::impl_param > get_default_impl_list_reduction()
Get list of available reduction implementations.
Definition reduction.cpp:84
std::vector< shamalgs::impl_param > get_default_impl_list_scan_exclusive_sum_in_place()
Get list of available scan_exclusive_sum_in_place implementations.
void set_impl_segmented_sort_in_place(const std::string &impl, const std::string &param="")
Set the implementation for segmented sort in place.
void set_impl_is_all_true(const std::string &impl, const std::string &param="")
Set the implementation for is_all_true.
shamalgs::impl_param get_current_impl_scan_exclusive_sum_in_place()
Get the current implementation for scan_exclusive_sum_in_place.
shamalgs::impl_param get_current_impl_segmented_sort_in_place()
Get the current implementation for segmented sort in place.
shamalgs::impl_param get_current_impl_reduction()
Get the current implementation for reduction.
Definition reduction.cpp:95
void set_impl_scan_exclusive_sum_in_place(const std::string &impl, const std::string &param="")
Set the implementation for scan_exclusive_sum_in_place.
std::vector< shamalgs::impl_param > get_default_impl_list_is_all_true()
Get list of available is_all_true implementations.
shamalgs::impl_param get_current_impl_is_all_true()
Get the current implementation for is_all_true.
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
bool is_all_true(sycl::buffer< T > &buf, u32 cnt)
Check if all elements in a sycl::buffer are non-zero.
void scan_exclusive_sum_in_place(sham::DeviceBuffer< T > &buf1, u32 len)
Compute exclusive prefix sum in-place on a device buffer.
f64 timeitfor(Func &&f, f64 max_duration=1)
Measures the average time it takes to execute a function until a maximum duration is reached.
Definition time.hpp:207
Pybind11 include and definitions.
#define Register_pymod(placeholdername)
Register a python module init function using static initialisation.
In-place exclusive scan (prefix sum) algorithm for device buffers.