30 for (
u32 i = 0; i < cnt; i++) {
56 [](
u32 i,
const T *in,
u32 *out) {
62 return count_true == cnt;
69 enum class IS_ALL_TRUE_IMPL :
u32 { HOST, SUM_REDUCTION };
70 IS_ALL_TRUE_IMPL is_all_true_impl = IS_ALL_TRUE_IMPL::HOST;
72 inline IS_ALL_TRUE_IMPL is_all_true_impl_from_params(
const std::string &impl) {
74 return IS_ALL_TRUE_IMPL::HOST;
75 }
else if (impl ==
"sum_reduction") {
76 return IS_ALL_TRUE_IMPL::SUM_REDUCTION;
79 "invalid implementation : {}, possible implementations : {}",
85 if (impl == IS_ALL_TRUE_IMPL::HOST) {
87 }
else if (impl == IS_ALL_TRUE_IMPL::SUM_REDUCTION) {
88 return {
"sum_reduction",
""};
91 shambase::format(
"unknown is_all_true implementation : {}",
u32(impl)));
95 std::vector<shamalgs::impl_param> impl_list{{
"host",
""}, {
"sum_reduction",
""}};
100 shamlog_info_ln(
"tree",
"setting is_all_true implementation to impl :", impl);
101 is_all_true_impl = is_all_true_impl_from_params(impl);
105 return is_all_true_impl_to_params(is_all_true_impl);
110 switch (is_all_true_impl) {
111 case IS_ALL_TRUE_IMPL::HOST :
return is_all_true_host(buf, cnt);
112 case IS_ALL_TRUE_IMPL::SUM_REDUCTION:
return is_all_true_sum_reduction(buf, cnt);
115 shambase::format(
"unimplemented case : {}",
u32(is_all_true_impl)));
128 sycl::host_accessor acc{buf, sycl::read_only};
130 for (
u32 i = 0; i < cnt; i++) {
std::uint32_t u32
32 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
std::shared_ptr< DeviceScheduler > & get_dev_scheduler_ptr()
Gets the Device scheduler pointer corresponding to the held allocation.
std::vector< T > copy_to_stdvec() const
Copy the content of the buffer to a std::vector.
Boolean reduction algorithm for checking if all elements are non-zero.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
void set_impl_is_all_true(const std::string &impl, const std::string &param="")
Set the implementation for is_all_true.
std::vector< shamalgs::impl_param > get_default_impl_list_is_all_true()
Get list of available is_all_true implementations.
shamalgs::impl_param get_current_impl_is_all_true()
Get the current implementation for is_all_true.
Namespace for primitive algorithms (e.g. sort, scan, reductions, ...)
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
bool is_all_true(sycl::buffer< T > &buf, u32 cnt)
Check if all elements in a sycl::buffer are non-zero.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
A class that references multiple buffers or similar objects.