31#include <pybind11/complex.h>
35 py::module shamalgs_module = m.def_submodule(
"algs",
"algorithmic library");
37 py::class_<std::mt19937>(shamalgs_module,
"rng");
39 py::class_<shamalgs::impl_param>(shamalgs_module,
"impl_param")
45 &shamalgs::impl_param::impl_name,
46 py::return_value_policy::reference_internal)
48 "params", &shamalgs::impl_param::params, py::return_value_policy::reference_internal)
52 return shambase::format(
53 "impl_param(impl_name=\"{}\", params=\"{}\")",
58 return shambase::format(
59 "impl_param(impl_name=\"{}\", params=\"{}\")",
64 shamalgs_module.def(
"gen_seed", [](
u64 seed) {
65 return std::mt19937(seed);
68 shamalgs_module.def(
"mock_gaussian", [](std::mt19937 &eng) {
69 return shamalgs::random::mock_gaussian<f64>(eng);
71 shamalgs_module.def(
"mock_gaussian_f64_2", [](std::mt19937 &eng) {
72 return shamalgs::random::mock_gaussian_multidim<f64_2>(eng);
74 shamalgs_module.def(
"mock_gaussian_f64_3", [](std::mt19937 &eng) {
75 return shamalgs::random::mock_gaussian_multidim<f64_3>(eng);
77 shamalgs_module.def(
"mock_unit_vector_f64_3", [](std::mt19937 &eng) {
78 return shamalgs::random::mock_unit_vector<f64_3>(eng);
81 shamalgs_module.def(
"mock_buffer_f64", [](
u64 seed,
u32 len,
f64 min_bound,
f64 max_bound) {
82 return shamalgs::random::mock_buffer_usm<f64>(
83 shamsys::instance::get_compute_scheduler_ptr(), seed, len, min_bound, max_bound);
85 shamalgs_module.def(
"mock_buffer_u8", [](
u64 seed,
u32 len,
u8 min_bound,
u8 max_bound) {
86 return shamalgs::random::mock_buffer_usm<u8>(
87 shamsys::instance::get_compute_scheduler_ptr(), seed, len, min_bound, max_bound);
89 shamalgs_module.def(
"mock_buffer_u32", [](
u64 seed,
u32 len,
u32 min_bound,
u32 max_bound) {
90 return shamalgs::random::mock_buffer_usm<u32>(
91 shamsys::instance::get_compute_scheduler_ptr(), seed, len, min_bound, max_bound);
94 "mock_buffer_f64_2", [](
u64 seed,
u32 len, f64_2 min_bound, f64_2 max_bound) {
95 return shamalgs::random::mock_buffer_usm<f64_2>(
96 shamsys::instance::get_compute_scheduler_ptr(), seed, len, min_bound, max_bound);
99 "mock_buffer_f64_3", [](
u64 seed,
u32 len, f64_3 min_bound, f64_3 max_bound) {
100 return shamalgs::random::mock_buffer_usm<f64_3>(
101 shamsys::instance::get_compute_scheduler_ptr(), seed, len, min_bound, max_bound);
121 "set_impl_is_all_true", [](
const std::string &impl,
const std::string ¶m =
"") {
125 shamalgs_module.def(
"get_current_impl_is_all_true", []() {
129 shamalgs_module.def(
"get_default_impl_list_is_all_true", []() {
137 shamsys::instance::get_compute_scheduler_ptr(), buf, start_id, end_id);
145 shamsys::instance::get_compute_scheduler_ptr(), buf, 0, len);
155 shamsys::instance::get_compute_scheduler_ptr(), buf, 0, len);
161 "set_impl_reduction", [](
const std::string &impl,
const std::string ¶m =
"") {
165 shamalgs_module.def(
"get_current_impl_reduction", []() {
169 shamalgs_module.def(
"get_default_impl_list_reduction", []() {
193 "set_impl_scan_exclusive_sum_in_place",
194 [](
const std::string &impl,
const std::string ¶m =
"") {
198 shamalgs_module.def(
"get_current_impl_scan_exclusive_sum_in_place", []() {
202 shamalgs_module.def(
"get_default_impl_list_scan_exclusive_sum_in_place", []() {
209 "segmented_sort_in_place",
211 shamalgs::primitives::segmented_sort_in_place(buf, offsets);
215 "benchmark_segmented_sort_in_place",
217 auto buf_copy = buf.
copy();
218 auto offsets_copy = offsets.
copy();
221 offsets_copy.synchronize();
226 shamalgs::primitives::segmented_sort_in_place(buf_copy, offsets_copy);
227 buf_copy.synchronize();
228 offsets_copy.synchronize();
235 "set_impl_segmented_sort_in_place",
236 [](
const std::string &impl,
const std::string ¶m =
"") {
240 shamalgs_module.def(
"get_current_impl_segmented_sort_in_place", []() {
244 shamalgs_module.def(
"get_default_impl_list_segmented_sort_in_place", []() {
249 py::class_<shamalgs::primitives::ImplControl>(shamalgs_module,
"ImplControl")
253 return impl_control.get_alg_name();
258 return impl_control.was_configured(shamsys::instance::get_compute_scheduler_ptr());
263 return impl_control.get_config(shamsys::instance::get_compute_scheduler_ptr());
268 impl_control.set_config(shamsys::instance::get_compute_scheduler_ptr(), config);
271 "get_default_config",
273 return impl_control.get_default_config(
274 shamsys::instance::get_compute_scheduler_ptr());
277 return impl_control.get_avail_configs(shamsys::instance::get_compute_scheduler_ptr());
281 "compute_histogram_impl",
283 return shamalgs::primitives::impl::compute_histogram_impl_control;
285 py::return_value_policy::reference);
288 "compute_histogram_basic_f64",
292 return shamalgs::primitives::compute_histogram_basic<f64>(
293 shamsys::instance::get_compute_scheduler_ptr(),
299 "compute_histogram_basic_f32",
303 return shamalgs::primitives::compute_histogram_basic<f32>(
304 shamsys::instance::get_compute_scheduler_ptr(),
311 "benchmark_compute_histogram_basic_f64",
320 auto result = shamalgs::primitives::compute_histogram_basic<f64>(
321 shamsys::instance::get_compute_scheduler_ptr(),
325 result.synchronize();
333 "benchmark_compute_histogram_basic_f32",
342 auto result = shamalgs::primitives::compute_histogram_basic<f32>(
343 shamsys::instance::get_compute_scheduler_ptr(),
347 result.synchronize();
Header file describing a Node Instance.
double f64
Alias for double.
float f32
Alias for float.
std::uint8_t u8
8 bit unsigned integer
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void synchronize() const
Wait for all the events associated with the buffer to be completed.
DeviceBuffer< T, target > copy() const
Copy the current buffer.
Class Timer measures the time elapsed since the timer was started.
void end()
Stops the timer and stores the elapsed time in nanoseconds.
f64 elasped_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
void start()
Starts the timer.
Boolean reduction algorithm for checking if all elements are non-zero.
std::vector< shamalgs::impl_param > get_default_impl_list_segmented_sort_in_place()
Get list of available segmented sort in place implementations.
void set_impl_reduction(const std::string &impl, const std::string &param="")
Set the implementation for reduction.
std::vector< shamalgs::impl_param > get_default_impl_list_reduction()
Get list of available reduction implementations.
std::vector< shamalgs::impl_param > get_default_impl_list_scan_exclusive_sum_in_place()
Get list of available scan_exclusive_sum_in_place implementations.
void set_impl_segmented_sort_in_place(const std::string &impl, const std::string &param="")
Set the implementation for segmented sort in place.
void set_impl_is_all_true(const std::string &impl, const std::string &param="")
Set the implementation for is_all_true.
shamalgs::impl_param get_current_impl_scan_exclusive_sum_in_place()
Get the current implementation for scan_exclusive_sum_in_place.
shamalgs::impl_param get_current_impl_segmented_sort_in_place()
Get the current implementation for segmented sort in place.
shamalgs::impl_param get_current_impl_reduction()
Get the current implementation for reduction.
void set_impl_scan_exclusive_sum_in_place(const std::string &impl, const std::string &param="")
Set the implementation for scan_exclusive_sum_in_place.
std::vector< shamalgs::impl_param > get_default_impl_list_is_all_true()
Get list of available is_all_true implementations.
shamalgs::impl_param get_current_impl_is_all_true()
Get the current implementation for is_all_true.
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
bool is_all_true(sycl::buffer< T > &buf, u32 cnt)
Check if all elements in a sycl::buffer are non-zero.
void scan_exclusive_sum_in_place(sham::DeviceBuffer< T > &buf1, u32 len)
Compute exclusive prefix sum in-place on a device buffer.
f64 timeitfor(Func &&f, f64 max_duration=1)
Measures the average time it takes to execute a function until a maximum duration is reached.
Pybind11 include and definitions.
#define Register_pymod(placeholdername)
Register a python module init function using static initialisation.
In-place exclusive scan (prefix sum) algorithm for device buffers.