Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
RenderFieldGetter.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
27#include <string>
28
30
// Builds a one-component solvergraph::Field<Tfield> holding one value per object of every
// non-empty patch, filled according to `field_name`:
//  - "rho"       : rho_h(gpart_mass, hpart, Kernel::hfactd) per object (SYCL kernel)
//  - "inv_hpart" : 1 / hpart per object (SYCL kernel)
//  - "unity"     : every entry set to 1
//  - "custom"    : values returned by the Python callback `custom_getter`, called once per
//                  patch with (object count, dict built by shamrock::pdat_to_dic(pdat))
// These four names are only handled when Tfield is f64; "rho"/"inv_hpart"/"unity" also need
// Tscal == Tfield. Every other case copies the patch-data layout field named `field_name`.
//
// @param field_name    name of a derived field above, or of a Tfield field of the layout
// @param custom_getter optional Python callback; only accepted when field_name == "custom"
// @return the filled field
// @throws when custom_getter is set for a non-"custom" field, when the callback returns an
//         array whose size differs from the patch object count, or (throw_unimplemented)
//         when the layout field has nvar > 1.
//
// NOTE(review): file lines 37, 41, 51, 158 and 185 are not shown in this rendering.
31 template<class Tvec, class Tfield, template<class> class SPHKernel>
32 shamrock::solvergraph::Field<Tfield> RenderFieldGetter<Tvec, Tfield, SPHKernel>::build_field(
33 std::string field_name,
34 std::optional<std::function<py::array_t<Tfield>(size_t, pybind11::dict &)>> custom_getter) {
35
// Reject a custom getter for any field other than "custom" (the throwing call is on
// omitted line 37).
36 if (field_name != "custom" && custom_getter.has_value()) {
38 "custom_getter is only supported for the custom field");
39 }
40
42
43 using namespace shamrock;
44 using namespace shamrock::patch;
45
// Record the object count of every non-empty patch in `sizes`. `sizes` is declared on
// omitted line 41; it is passed to ensure_sizes, which takes a
// shambase::DistributedData<u32>.
46 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
47 sizes.add_obj(p.id_patch, pdat.get_obj_cnt());
48 });
49
// Allocates a field with `nvar` components per object and resizes it to match `sizes`
// (the declaration of `ret` is on omitted line 51).
50 auto make_field = [&](u32 nvar, std::string name, std::string texsymbol) {
52 = shamrock::solvergraph::Field<Tfield>(nvar, name, texsymbol);
53 ret.ensure_sizes(sizes);
54 return ret;
55 };
56
// Derived fields: compiled only for Tfield == f64. If no branch below returns (e.g. "custom"
// without a getter, or Tscal != Tfield), control falls through to the layout-field copy.
57 if constexpr (std::is_same_v<Tfield, f64>) {
58 if (field_name == "rho" && std::is_same_v<Tscal, Tfield>) {
59
60 auto density = make_field(1, "rho", "rho");
61
// One kernel per patch: rho[i] = rho_h(gpart_mass, hpart[i], Kernel::hfactd).
62 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
63 shamlog_debug_ln("sph::vtk", "compute rho field for patch ", p.id_patch);
64
65 auto &buf_h
66 = pdat.get_field<Tscal>(pdat.pdl().get_field_idx<Tscal>("hpart")).get_buf();
67 auto &buf_rho = density.get_buf(p.id_patch);
68
69 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
70
71 sham::EventList depends_list;
72
73 auto acc_h = buf_h.get_read_access(depends_list);
74 auto acc_rho = buf_rho.get_write_access(depends_list);
75
76 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
77 const Tscal part_mass = solver_config.gpart_mass;
78
79 cgh.parallel_for(
80 sycl::range<1>{pdat.get_obj_cnt()}, [=](sycl::item<1> item) {
81 u32 gid = (u32) item.get_id();
82 using namespace shamrock::sph;
83 Tscal rho_ha = rho_h(part_mass, acc_h[gid], Kernel::hfactd);
84 acc_rho[gid] = rho_ha;
85 });
86 });
87
// Release both accessors against the kernel's event.
88 buf_h.complete_event_state(e);
89 buf_rho.complete_event_state(e);
90 });
91
92 return density;
93 } else if (field_name == "inv_hpart" && std::is_same_v<Tscal, Tfield>) {
94
95 auto inv_hpart = make_field(1, "inv_hpart", "inv_hpart");
96
// One kernel per patch: inv_hpart[i] = 1 / hpart[i].
97 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
98 shamlog_debug_ln("sph::vtk", "compute inv_hpart field for patch ", p.id_patch);
99
100 auto &buf_h
101 = pdat.get_field<Tscal>(pdat.pdl().get_field_idx<Tscal>("hpart")).get_buf();
102 auto &buf_inv_hpart = inv_hpart.get_buf(p.id_patch);
103
104 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
105
106 sham::EventList depends_list;
107
108 auto acc_h = buf_h.get_read_access(depends_list);
109 auto acc_inv_hpart = buf_inv_hpart.get_write_access(depends_list);
110
111 auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
112 cgh.parallel_for(
113 sycl::range<1>{pdat.get_obj_cnt()}, [=](sycl::item<1> item) {
114 u32 gid = (u32) item.get_id();
// NOTE(review): this using-directive is unused; the kernel only computes 1 / hpart.
115 using namespace shamrock::sph;
116 acc_inv_hpart[gid] = 1.0 / acc_h[gid];
117 });
118 });
119
120 buf_h.complete_event_state(e);
121 buf_inv_hpart.complete_event_state(e);
122 });
123
124 return inv_hpart;
125 } else if (field_name == "unity" && std::is_same_v<Tscal, Tfield>) {
// NOTE(review): these using-directives repeat the ones at the top of the function and can
// be dropped.
126 using namespace shamrock;
127 using namespace shamrock::patch;
128
129 auto unity = make_field(1, "unity", "unity");
// Fill every patch buffer with 1; the `size` argument is unused.
130 sizes.for_each([&](u64 id_patch, u32 size) {
131 unity.get_buf(id_patch).fill(1);
132 });
133
134 return unity;
// Python-provided field: for each patch, the callback receives the object count and the
// patch data as a dict, and must return exactly that many values.
135 } else if (field_name == "custom" && custom_getter.has_value()) {
136 std::function<py::array_t<Tfield>(size_t, pybind11::dict &)> &field_source_getter
137 = custom_getter.value();
138
139 auto custom = make_field(1, "custom", "custom");
140
141 shambase::Timer timer;
142 timer.start();
143
144 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
145 shamlog_debug_ln("sph::vtk", "compute custom field for patch ", p.id_patch);
146
147 auto &buf_custom = custom.get_buf(p.id_patch);
148
// NOTE(review): `q` is never used in this branch; remove it.
149 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
150
151 py::dict dic_out = shamrock::pdat_to_dic(pdat);
// NOTE(review): this copies the whole, not-yet-filled device buffer to the host only so its
// size can be compared below; the copy is then overwritten. Comparing custom_array.size()
// against pdat.get_obj_cnt() (the count the buffer was sized with through `sizes`) would
// avoid the transfer.
152 std::vector<Tfield> acc_custom = buf_custom.copy_to_stdvec();
153
154 py::array_t<Tfield> custom_array
155 = field_source_getter(pdat.get_obj_cnt(), dic_out);
156
// NOTE(review): py::array_t::size() returns a signed py::ssize_t, while std::vector::size()
// is size_t; this is a signed/unsigned comparison and should cast explicitly. The throwing
// call is on omitted line 158.
157 if (acc_custom.size() != custom_array.size()) {
159 "custom_array size does not match the number of particles");
160 }
161
162 acc_custom = custom_array.template cast<std::vector<Tfield>>();
163
164 buf_custom.copy_from_stdvec(acc_custom);
165 });
166
167 timer.end();
168
// Log the slowest rank's time (max over ranks), printed only by rank 0.
169 f64 worse_time_rank = shamalgs::collective::allreduce_max(timer.elasped_sec());
170
171 if (shamcomm::world_rank() == 0) {
172 logger::raw_ln(
173 "sph::RenderFieldGetter",
174 "compute custom field took : ",
175 worse_time_rank,
176 "s");
177 }
178
179 return custom;
180 }
181 }
182
// NOTE(review): `field_source_getter` is never used in this function (the copy loop below
// fetches each patch buffer directly), and part of its parameter list is on omitted
// line 185. It can be deleted.
183 auto field_source_getter
184 = [&](const shamrock::patch::Patch cur_p,
186 return pdat.get_field<Tfield>(pdat.pdl().get_field_idx<Tfield>(field_name)).get_buf();
187 };
188
// Fallback: copy the layout field `field_name` as-is. Presumably the layout lookup throws
// if no Tfield field with that name exists (confirm). This path is also reached for
// "custom" without a getter, which then fails on that lookup rather than with a targeted
// message.
189 FieldDescriptor<Tfield> desc = scheduler().pdl_old().template get_field<Tfield>(field_name);
// NOTE(review): `ifield` is unused; the loop below recomputes the index from each patch's
// own layout. Either use it or remove it.
190 u32 ifield = scheduler().pdl_old().template get_field_idx<Tfield>(field_name);
191
192 if (desc.nvar > 1) {
193 shambase::throw_unimplemented("this cannot handle cases with nvar > 1, yet ...");
194 }
195
196 auto ret = make_field(1, desc.name, desc.name);
197
198 scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) {
199 sham::DeviceBuffer<Tfield> &buf = ret.get_buf(p.id_patch);
200 buf.copy_from(
201 pdat.get_field<Tfield>(pdat.pdl().get_field_idx<Tfield>(field_name)).get_buf());
202 });
203
204 return ret;
205 }
206
// Builds the field for `field_name` with build_field() and invokes `lambda` with a getter
// that maps a patch to that field's device buffer for the patch's id. Returns whatever
// `lambda` returns.
//
// @param field_name    forwarded to build_field()
// @param lambda        callback receiving the per-patch buffer getter
// @param custom_getter forwarded to build_field(); only valid with field_name == "custom"
//
// NOTE(review): the getter captures the local `field` by reference, so it is only valid for
// the duration of the `lambda(...)` call. This assumes lamda_runner does not keep the getter
// afterwards; confirm.
// NOTE(review): file lines 212 and 218 are not shown in this rendering.
207 template<class Tvec, class Tfield, template<class> class SPHKernel>
208 auto RenderFieldGetter<Tvec, Tfield, SPHKernel>::runner_function(
209 std::string field_name,
210 lamda_runner lambda,
211 std::optional<std::function<py::array_t<Tfield>(size_t, pybind11::dict &)>> custom_getter)
213
214 auto field = build_field(field_name, custom_getter);
215
216 auto field_source_getter
217 = [&](const shamrock::patch::Patch cur_p,
219 return field.get_buf(cur_p.id_patch);
220 };
221
222 return lambda(field_source_getter);
223 }
224} // namespace shammodels::sph::modules
225
// NOTE(review): file-scope using-directive placed after the namespace is closed. Whatever
// relies on it lives on file lines 227-238, which are omitted from this rendering
// (presumably explicit template instantiations; confirm). If it is not needed there,
// prefer qualified shammath:: names.
226using namespace shammath;
230
234
238
constexpr const char * density
Density \rho (derived from h)
constexpr const char * sizes
Temporary sizes for h-iteration.
double f64
Alias for double.
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void copy_from(const DeviceBuffer< T, new_target > &other, size_t copy_size)
Copies the content of another buffer to this one.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
Represents a collection of objects distributed across patches identified by a u64 id.
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:96
void end()
Stops the timer and stores the elapsed time in nanoseconds.
Definition time.hpp:111
f64 elasped_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
Definition time.hpp:123
void start()
Starts the timer.
Definition time.hpp:106
Structure describing a field in a patch data layout.
std::string name
The name of the field.
u32 nvar
The number of variables of the field per object.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
virtual void ensure_sizes(const shambase::DistributedData< u32 > &sizes)
Ensure that the sizes of the patches in the field match the given sizes (Can resize the underlying fi...
Definition Field.hpp:92
This header file contains utility functions related to exception handling in the code.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for math utility
Definition AABB.hpp:26
namespace for the sph model modules
namespace for the main framework
Definition __init__.py:1
Patch object that contains generic patch information.
Definition Patch.hpp:33
u64 id_patch
unique key that identifies the patch
Definition Patch.hpp:86