Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
DistributedBuffers.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
24
25namespace shamrock::solvergraph {
26
28 template<class T>
30
38 template<class T>
40 public:
41 using IEdgeNamed::IEdgeNamed;
42
43 DDDeviceBuffer<T> buffers;
44
45 inline virtual void free_alloc() { buffers = {}; }
46
47 inline virtual void check_allocated(const std::vector<u64> &ids) const {
48 on_distributeddata_ids_diff(
49 buffers,
50 ids,
51 [](u64 id) {
53 "Missing buffer in distributed data at id " + std::to_string(id));
54 },
55 [](u64 id) {},
56 [](u64 id) {
58 "Extra buffer in distributed data at id " + std::to_string(id));
59 });
60 }
61
62 // overload only the non const case
63 inline virtual void ensure_allocated(const std::vector<u64> &ids) {
64
65 auto new_buf = [&]() {
66 auto ret = sham::DeviceBuffer<T>(0, shamsys::instance::get_compute_scheduler_ptr());
67 return ret;
68 };
69
71 buffers,
72 ids,
73 [&](u64 id) {
74 buffers.add_obj(id, new_buf());
75 },
76 [](u64 id) {
77 // Nothing for now
78 },
79 [&](u64 id) {
80 buffers.erase(id);
81 });
82 }
83 };
84
85} // namespace shamrock::solvergraph
Header file describing a Node Instance.
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
Represents a collection of objects distributed across patches identified by a u64 id.
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
void erase(u64 id)
Removes an object from the collection.
Interface for a solver graph edge representing a field as spans.
virtual void free_alloc()
Free allocated memory.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
void on_distributeddata_ids_diff(const shambase::DistributedData< T1 > &dd, const std::vector< u64 > &ref_ids, FuncMatch &&func_missing, FuncMissing &&func_match, FuncExtra &&func_extra)
Compare the ids held by a distributed data against a reference list of ids, and apply the missing/match/extra callbacks according to the difference.