Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
dot_sum.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
19#include "shambackends/math.hpp"
20#include "shambackends/vec.hpp"
21#include <stdexcept>
22
23namespace shamalgs::primitives {
24
25 template<class T>
26 shambase::VecComponent<T> dot_sum(sham::DeviceBuffer<T> &buf1, u32 start_id, u32 end_id) {
27 using Tscal = shambase::VecComponent<T>;
28
29 if (start_id == end_id) {
30 return Tscal(0);
31 }
32
33 if (start_id > end_id) {
35 shambase::format("start_id > end_id : {} > {}", start_id, end_id));
36 }
37
38 sham::DeviceBuffer<Tscal> ret_data_base(end_id - start_id, buf1.get_dev_scheduler_ptr());
39
41 buf1.get_queue(),
42 sham::MultiRef{buf1},
43 sham::MultiRef{ret_data_base},
44 end_id - start_id,
45 [start_id](u32 i, const T *__restrict buf1, Tscal *__restrict out) {
46 T in = buf1[i + start_id];
47 out[i] = sham::dot(in, in);
48 });
49
51 buf1.get_dev_scheduler_ptr(), ret_data_base, 0, end_id - start_id);
52 }
53
54#ifndef DOXYGEN
55 #define XMAC_TYPES \
56 X(f32) \
57 X(f32_2) \
58 X(f32_3) \
59 X(f32_4) \
60 X(f32_8) \
61 X(f32_16) \
62 X(f64) \
63 X(f64_2) \
64 X(f64_3) \
65 X(f64_4) \
66 X(f64_8) \
67 X(f64_16) \
68 X(u32) \
69 X(u64) \
70 X(i32) \
71 X(i64) \
72 X(u32_3) \
73 X(u64_3) \
74 X(i64_3) \
75 X(i32_3)
76
77 #define X(_arg_) \
78 template shambase::VecComponent<_arg_> dot_sum( \
79 sham::DeviceBuffer<_arg_> &buf1, u32 start_id, u32 end_id);
80
81 XMAC_TYPES
82 #undef X
83#endif
84} // namespace shamalgs::primitives
std::uint32_t u32
32 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
DeviceQueue & get_queue() const
Gets the DeviceQueue associated with the held allocation.
std::shared_ptr< DeviceScheduler > & get_dev_scheduler_ptr()
Gets the Device scheduler pointer corresponding to the held allocation.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
Namespace for primitive algorithms (e.g. sort, scan, reductions, ...)
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
shambase::VecComponent< T > dot_sum(sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of dot products of elements in a device buffer with themselves.
Definition dot_sum.cpp:26
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
A class that references multiple buffers or similar objects.