Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
is_all_true.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
18#include "shambase/memory.hpp"
21
22namespace {
23
24 template<class T>
25 bool is_all_true_host(sham::DeviceBuffer<T> &buf, u32 cnt) {
26
27 {
28 auto tmp = buf.copy_to_stdvec();
29
30 for (u32 i = 0; i < cnt; i++) {
31 if (tmp[i] == 0) {
32 return false;
33 }
34 }
35 }
36
37 return true;
38 }
39
// Checks that the first cnt entries of buf are all non-zero using a reduction:
// a 0/1 flag is written per entry into a temporary u32 buffer, the flags are
// summed over [0, cnt), and the result is compared with cnt.
// Returns true immediately when cnt is 0.
//
// NOTE(review): listing line 51 is absent from this view. The argument list on
// lines 52-58 (queue, input MultiRef, output MultiRef, element count, functor)
// matches the kernel_call signature referenced by this page, so the missing
// line is presumably the `sham::kernel_call(` opener -- confirm against the
// real file before editing this function.
40 template<class T>
41 bool is_all_true_sum_reduction(sham::DeviceBuffer<T> &buf, u32 cnt) {
42
// Nothing to test: vacuously true.
43 if (cnt == 0) {
44 return true;
45 }
46
// Reuse the scheduler that owns buf for the temporary and the reduction.
47 auto dev_sched = buf.get_dev_scheduler_ptr();
48
// Scratch buffer for the per-entry flags (constructed from cnt and the scheduler).
49 sham::DeviceBuffer<u32> tmp(cnt, dev_sched);
50
52 shambase::get_check_ref(dev_sched).get_queue(),
53 sham::MultiRef{buf},
54 sham::MultiRef{tmp},
55 cnt,
56 [](u32 i, const T *in, u32 *out) {
// The comparison result (bool) is stored as 0 or 1.
57 out[i] = in[i] != 0;
58 });
59
// Sum of flags over the index range starting at 0 and ending at cnt.
60 auto count_true = shamalgs::primitives::sum(dev_sched, tmp, 0, cnt);
61
// All entries are non-zero exactly when every flag contributed 1.
62 return count_true == cnt;
63 }
64
65} // namespace
66
67namespace shamalgs::primitives {
68
69 enum class IS_ALL_TRUE_IMPL : u32 { HOST, SUM_REDUCTION };
70 IS_ALL_TRUE_IMPL is_all_true_impl = IS_ALL_TRUE_IMPL::HOST;
71
// Maps an implementation name to the matching IS_ALL_TRUE_IMPL value:
// "host" -> HOST, "sum_reduction" -> SUM_REDUCTION.
//
// NOTE(review): listing lines 78 and 81 are absent from this view. What
// remains (a format string plus the `impl` argument) indicates that any other
// name ends in an error path reporting the invalid name and the possible
// implementations; the exact call and exception type cannot be seen here --
// confirm against the real file.
72 inline IS_ALL_TRUE_IMPL is_all_true_impl_from_params(const std::string &impl) {
73 if (impl == "host") {
74 return IS_ALL_TRUE_IMPL::HOST;
75 } else if (impl == "sum_reduction") {
76 return IS_ALL_TRUE_IMPL::SUM_REDUCTION;
77 }
// Unrecognised name: fall through to the error report.
79 "invalid implementation : {}, possible implementations : {}",
80 impl,
82 }
83
// Inverse of the name lookup: converts an IS_ALL_TRUE_IMPL value into an
// impl_param whose first field is the implementation name ("host" or
// "sum_reduction") and whose second field is an empty string.
//
// NOTE(review): listing line 90 is absent from this view. The remaining line
// formats "unknown is_all_true implementation" with the numeric enum value, so
// an unhandled value presumably raises an error -- confirm the exact call
// against the real file.
84 inline shamalgs::impl_param is_all_true_impl_to_params(const IS_ALL_TRUE_IMPL &impl) {
85 if (impl == IS_ALL_TRUE_IMPL::HOST) {
86 return {"host", ""};
87 } else if (impl == IS_ALL_TRUE_IMPL::SUM_REDUCTION) {
88 return {"sum_reduction", ""};
89 }
// Value outside the known enumerators: report it as its underlying u32.
91 shambase::format("unknown is_all_true implementation : {}", u32(impl)));
92 }
93
94 std::vector<shamalgs::impl_param> impl::get_default_impl_list_is_all_true() {
95 std::vector<shamalgs::impl_param> impl_list{{"host", ""}, {"sum_reduction", ""}};
96 return impl_list;
97 }
98
99 void impl::set_impl_is_all_true(const std::string &impl, const std::string &param) {
100 shamlog_info_ln("tree", "setting is_all_true implementation to impl :", impl);
101 is_all_true_impl = is_all_true_impl_from_params(impl);
102 }
103
// NOTE(review): the signature line of this function (listing line 104) is
// absent from this view; the page's reference list suggests it is
// `shamalgs::impl_param impl::get_current_impl_is_all_true()` -- confirm.
// The body converts the currently selected implementation into its impl_param form.
105 return is_all_true_impl_to_params(is_all_true_impl);
106 }
107
// Dispatcher: forwards (buf, cnt) to the implementation selected by
// is_all_true_impl and returns its result.
//
// NOTE(review): listing lines 109 (the function signature) and 114 (start of
// the default-case statement) are absent from this view. The explicit
// instantiation that follows this definition suggests the signature is
// `bool is_all_true(sham::DeviceBuffer<T> &buf, u32 cnt)`, and the default
// case formats an "unimplemented case" message, presumably to raise an
// error -- confirm both against the real file.
108 template<class T>
110 switch (is_all_true_impl) {
111 case IS_ALL_TRUE_IMPL::HOST : return is_all_true_host(buf, cnt);
112 case IS_ALL_TRUE_IMPL::SUM_REDUCTION: return is_all_true_sum_reduction(buf, cnt);
113 default:
// Enum value with no matching case: report it as its underlying u32.
115 shambase::format("unimplemented case : {}", u32(is_all_true_impl)));
116 }
117 }
118
119 template bool is_all_true(sham::DeviceBuffer<u8> &buf, u32 cnt);
120
121} // namespace shamalgs::primitives
122
123template<class T>
124bool shamalgs::primitives::is_all_true(sycl::buffer<T> &buf, u32 cnt) {
125
126 // TODO do it on GPU pleeeaze
127 {
128 sycl::host_accessor acc{buf, sycl::read_only};
129
130 for (u32 i = 0; i < cnt; i++) {
131 if (acc[i] == 0) {
132 return false;
133 }
134 }
135 }
136
137 return true;
138}
139
140template bool shamalgs::primitives::is_all_true(sycl::buffer<u8> &buf, u32 cnt);
std::uint32_t u32
32 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
std::shared_ptr< DeviceScheduler > & get_dev_scheduler_ptr()
Gets the Device scheduler pointer corresponding to the held allocation.
std::vector< T > copy_to_stdvec() const
Copy the content of the buffer to a std::vector.
Boolean reduction algorithm for checking if all elements are non-zero.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
void set_impl_is_all_true(const std::string &impl, const std::string &param="")
Set the implementation for is_all_true.
std::vector< shamalgs::impl_param > get_default_impl_list_is_all_true()
Get list of available is_all_true implementations.
shamalgs::impl_param get_current_impl_is_all_true()
Get the current implementation for is_all_true.
namespace for primitive algorithm (e.g. sort, scan, reductions, ...)
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
bool is_all_true(sycl::buffer< T > &buf, u32 cnt)
Check if all elements in a sycl::buffer are non-zero.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the pointer does not hold anything.
Definition memory.hpp:110
A class that references multiple buffers or similar objects.