Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
equals.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
32#include "shambackends/math.hpp"
33#include "shambackends/sycl.hpp"
35
36namespace shamalgs::primitives {
37
76 template<class T>
77 bool equals(sycl::queue &q, sycl::buffer<T> &buf1, sycl::buffer<T> &buf2, u32 cnt) {
78
79 if (buf1.size() < cnt) {
80 throw shambase::make_except_with_loc<std::invalid_argument>("buf 1 is larger than cnt");
81 }
82
83 if (buf2.size() < cnt) {
84 throw shambase::make_except_with_loc<std::invalid_argument>("buf 2 is larger than cnt");
85 }
86
87 sham::DeviceBuffer<u8> res(cnt, shamsys::instance::get_compute_scheduler_ptr());
88
89 sham::EventList deps;
90 auto out = res.get_write_access(deps);
91
92 deps.set_consumed(true);
93 auto e = q.submit([&](sycl::handler &cgh) {
94 cgh.depends_on(deps.get_events());
95 sycl::accessor acc1{buf1, cgh, sycl::read_only};
96 sycl::accessor acc2{buf2, cgh, sycl::read_only};
97
98 cgh.parallel_for(sycl::range{cnt}, [=](sycl::item<1> item) {
99 out[item] = sham::equals(acc1[item], acc2[item]);
100 });
101 });
102
104
105 return shamalgs::primitives::is_all_true(res, cnt);
106 }
107
153 template<class T>
154 inline bool equals(
155 const sham::DeviceScheduler_ptr &dev_sched,
158 u32 cnt) {
159
160 // kernel call does not support 0 elements
161 if (cnt == 0) {
162 return true;
163 }
164
165 // if the buffers are the same early return true
166 if (&buf1 == &buf2) {
167 return true;
168 }
169
170 if (buf1.get_size() < cnt) {
171 throw shambase::make_except_with_loc<std::invalid_argument>("buf 1 is larger than cnt");
172 }
173
174 if (buf2.get_size() < cnt) {
175 throw shambase::make_except_with_loc<std::invalid_argument>("buf 2 is larger than cnt");
176 }
177
178 sham::DeviceBuffer<u8> res(cnt, dev_sched);
179
180 auto &q = shambase::get_check_ref(dev_sched).get_queue();
181
183 q,
184 sham::MultiRef{buf1, buf2},
185 sham::MultiRef{res},
186 cnt,
187 [](u32 i, const T *__restrict acc1, const T *__restrict acc2, u8 *__restrict out) {
188 out[i] = sham::equals(acc1[i], acc2[i]);
189 });
190
191 return shamalgs::primitives::is_all_true(res, cnt);
192 }
193
234 template<class T>
235 inline bool equals(
236 const sham::DeviceScheduler_ptr &q,
238 sham::DeviceBuffer<T> &buf2) {
239
240 bool same_size = buf1.get_size() == buf2.get_size();
241 if (!same_size) {
242 return false;
243 }
244
245 return equals(q, buf1, buf2, buf1.get_size());
246 }
247
279 template<class T>
280 bool equals(sycl::queue &q, sycl::buffer<T> &buf1, sycl::buffer<T> &buf2) {
281 bool same_size = buf1.size() == buf2.size();
282 if (!same_size) {
283 return false;
284 }
285
286 return equals(q, buf1, buf2, buf1.size());
287 }
288
328 template<class T>
330 sycl::queue &q,
331 const std::unique_ptr<sycl::buffer<T>> &buf1,
332 const std::unique_ptr<sycl::buffer<T>> &buf2,
333 u32 cnt) {
334 bool same_alloc = bool(buf1) == bool(buf2);
335
336 if (!same_alloc) {
337 return false;
338 }
339
340 if (!bool(buf1)) {
341 return true;
342 }
343
344 return equals(q, *buf1, *buf2, cnt);
345 }
346
385 template<class T>
387 sycl::queue &q,
388 const std::unique_ptr<sycl::buffer<T>> &buf1,
389 const std::unique_ptr<sycl::buffer<T>> &buf2) {
390 bool same_alloc = bool(buf1) == bool(buf2);
391
392 if (!same_alloc) {
393 return false;
394 }
395
396 if (!bool(buf1)) {
397 return true;
398 }
399
400 return equals(q, *buf1, *buf2);
401 }
402} // namespace shamalgs::primitives
Header file describing a Node Instance.
std::uint8_t u8
8 bit unsigned integer
std::uint32_t u32
32 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
size_t get_size() const
Gets the number of elements in the buffer.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
void set_consumed(bool consumed)
Set the consumed state of the EventList (to be used with interop)
std::vector< sycl::event > & get_events()
Get the list of events.
This header file contains utility functions related to exception handling in the code.
Boolean reduction algorithm for checking if all elements are non-zero.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
namespace for primitive algorithm (e.g. sort, scan, reductions, ...)
bool equals(sycl::queue &q, sycl::buffer< T > &buf1, sycl::buffer< T > &buf2, u32 cnt)
Compare elements between two sycl::buffers for equality.
Definition equals.hpp:77
bool is_all_true(sycl::buffer< T > &buf, u32 cnt)
Check if all elements in a sycl::buffer are non-zero.
bool equals_ptr(sycl::queue &q, const std::unique_ptr< sycl::buffer< T > > &buf1, const std::unique_ptr< sycl::buffer< T > > &buf2)
Compare all elements between two unique_ptr-wrapped sycl::buffers.
Definition equals.hpp:386
bool equals_ptr_s(sycl::queue &q, const std::unique_ptr< sycl::buffer< T > > &buf1, const std::unique_ptr< sycl::buffer< T > > &buf2, u32 cnt)
Compare elements between two unique_ptr-wrapped sycl::buffers with count.
Definition equals.hpp:329
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the pointer does not hold an object.
Definition memory.hpp:110
A class that references multiple buffers or similar objects.