Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
Field.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
23
24namespace shamrock::solvergraph {
25
26 template<class T>
27 class Field : public IFieldRefs<T> {
28
29 // TODO In the long run this class should become what was compute field
30
31 u32 nvar;
32 std::string name;
33 ComputeField<T> field;
34
35 DDPatchDataFieldRef<T> field_refs;
36
38
39 void sync() {
40 field_refs = field.field_data.template map<std::reference_wrapper<PatchDataField<T>>>(
41 [&](u64 id, PatchDataField<T> &pdf) {
42 return std::ref(pdf);
43 });
44 spans = field_refs.template map<shamrock::PatchDataFieldSpanPointer<T>>(
45 [&](u64 id, std::reference_wrapper<PatchDataField<T>> &pdf) {
46 return pdf.get().get_pointer_span();
47 });
48 }
49
50 public:
51 Field(u32 nvar, std::string name, std::string texsymbol)
52 : nvar(nvar), name(name), IFieldRefs<T>(name, texsymbol) {}
53
54 virtual DDPatchDataFieldRef<T> &get_refs() { return field_refs; }
55
56 virtual const DDPatchDataFieldRef<T> &get_refs() const { return field_refs; }
57
58 virtual DDPatchDataFieldSpanPointer<T> &get_spans() { return spans; }
59
60 virtual const DDPatchDataFieldSpanPointer<T> &get_spans() const { return spans; }
61
62 shambase::DistributedData<u32> get_obj_cnts() const { return field.get_obj_cnts(); }
63
64 inline virtual void check_sizes(const shambase::DistributedData<u32> &sizes) const {
65 on_distributeddata_diff(
66 field.field_data,
67 sizes,
68 [&](u64 id) {
69 shambase::throw_with_loc<std::runtime_error>(shambase::format(
70 "Missing field ref in distributed data at id {}\n"
71 "Field name: {}\n"
72 "Field texsymbol: {}",
73 id,
74 this->get_label(),
75 this->get_tex_symbol()));
76 },
77 [](u64 id) {
78 // TODO
79 },
80 [&](u64 id) {
81 shambase::throw_with_loc<std::runtime_error>(shambase::format(
82 "Extra field ref in distributed data at id {}\n"
83 "Field name: {}\n"
84 "Field texsymbol: {}",
85 id,
86 this->get_label(),
87 this->get_tex_symbol()));
88 });
89 }
90
91 // overload only the non const case
92 inline virtual void ensure_sizes(const shambase::DistributedData<u32> &sizes) {
93
94 auto new_patchdatafield = [&](u32 size) {
95 auto ret = PatchDataField<T>(name, nvar);
96 ret.resize(size);
97 return ret;
98 };
99
100 auto ensure_patchdatafield_sizes = [&](u32 size, auto &pdatfield) {
101 if (pdatfield.get_obj_cnt() != size) {
102 pdatfield.resize(size);
103 }
104 };
105
106 on_distributeddata_diff(
107 field.field_data,
108 sizes,
109 [&](u64 id) {
110 field.field_data.add_obj(id, new_patchdatafield(sizes.get(id)));
111 },
112 [&](u64 id) {
113 ensure_patchdatafield_sizes(sizes.get(id), field.field_data.get(id));
114 },
115 [&](u64 id) {
116 field.field_data.erase(id);
117 });
118
119 sync();
120 }
121
122 inline virtual void free_alloc() { field.field_data = {}; }
123
124 inline ComputeField<T> extract() { return std::exchange(field, {}); }
125
126 inline sham::DeviceBuffer<T> &get_buf(u64 id_patch) {
127 return field.field_data.get(id_patch).get_buf();
128 }
129
130 inline PatchDataField<T> &get(u64 id_patch) { return field.field_data.get(id_patch); }
131 inline const PatchDataField<T> &get(u64 id_patch) const {
132 return field.field_data.get(id_patch);
133 }
134
135 inline u32 get_nvar() const { return nvar; }
136 };
137} // namespace shamrock::solvergraph
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
Represents a collection of objects distributed across patches identified by a u64 id.
T & get(u64 id)
Returns a reference to an object in the collection.
virtual const DDPatchDataFieldRef< T > & get_refs() const
Const variant of get_refs.
Definition Field.hpp:56
virtual DDPatchDataFieldRef< T > & get_refs()
Get the DistributedData of PatchDataFieldRefs.
Definition Field.hpp:54
virtual DDPatchDataFieldSpanPointer< T > & get_spans()
Get the DistributedData of spans attached to the underlying field.
Definition Field.hpp:58
virtual void free_alloc()
Free allocated memory.
Definition Field.hpp:122
virtual const DDPatchDataFieldSpanPointer< T > & get_spans() const
Const variant of get_spans.
Definition Field.hpp:60
virtual void check_sizes(const shambase::DistributedData< u32 > &sizes) const
Check that the sizes of the patches in the field match the given sizes.
Definition Field.hpp:64
virtual void ensure_sizes(const shambase::DistributedData< u32 > &sizes)
Ensure that the sizes of the patches in the field match the given sizes (can resize the underlying field).
Definition Field.hpp:92
Interface for a solver graph edge representing a field as references to the underlying patch fields.