// Builds a shamrock::solvergraph::Field<Tfield> holding the values of the field
// selected by `field_name`.
//
// NOTE(review): this text is a numbered listing; some lines carry the original
// file's line number as a prefix and several original lines (throw statements,
// parallel_for calls, closing braces, return statements) are not present here.
// The comments below only describe what the visible lines show.
//
// @param field_name     Name of the requested field. The visible branches compare
//                       it against "rho", "inv_hpart", "unity" and "custom"; the
//                       trailing code looks the name up in the patch data layout.
// @param custom_getter  Optional callable taking (size_t, pybind11::dict &) and
//                       returning py::array_t<Tfield>. Supplying it with any
//                       field_name other than "custom" hits the error path below.
// @return A shamrock::solvergraph::Field<Tfield> (the return statements themselves
//         are not visible in this listing).
31 template<
class Tvec,
class Tfield,
template<
class>
class SPHKernel>
32 shamrock::solvergraph::Field<Tfield> RenderFieldGetter<Tvec, Tfield, SPHKernel>::build_field(
33 std::string field_name,
34 std::optional<std::function<py::array_t<Tfield>(
size_t, pybind11::dict &)>> custom_getter) {
// Reject a custom getter passed together with a non-"custom" field name.
// NOTE(review): the statement that consumes this message (presumably a throw) is
// on a line missing from the listing - confirm against the real file.
36 if (field_name !=
"custom" && custom_getter.has_value()) {
38 "custom_getter is only supported for the custom field");
// Per-patch object counts, keyed by patch id, gathered from every non-empty patch.
41 shambase::DistributedData<u32>
sizes{};
43 using namespace shamrock;
44 using namespace shamrock::patch;
46 scheduler().for_each_patchdata_nonempty([&](
const Patch p, PatchDataLayer &pdat) {
47 sizes.add_obj(p.id_patch, pdat.get_obj_cnt());
// Local factory constructing a Field<Tfield> from (nvar, name, texsymbol).
// NOTE(review): the rest of this lambda is not visible; presumably it sizes the
// field from `sizes` and returns it - confirm.
50 auto make_field = [&](
u32 nvar, std::string name, std::string texsymbol) {
51 shamrock::solvergraph::Field<Tfield> ret
52 = shamrock::solvergraph::Field<Tfield>(nvar, name, texsymbol);
// The derived-field branches below are only compiled when Tfield is f64, and each
// one additionally requires Tscal and Tfield to be the same type.
57 if constexpr (std::is_same_v<Tfield, f64>) {
// "rho": evaluate rho_h(part_mass, hpart, Kernel::hfactd) for every object.
58 if (field_name ==
"rho" && std::is_same_v<Tscal, Tfield>) {
60 auto density = make_field(1,
"rho",
"rho");
// For each non-empty patch: log, then read the "hpart" buffer (the declaration
// line of `buf_h` is missing from the listing) and write into the output buffer
// of this patch.
62 scheduler().for_each_patchdata_nonempty([&](
const Patch p, PatchDataLayer &pdat) {
63 shamlog_debug_ln(
"sph::vtk",
"compute rho field for patch ", p.id_patch);
66 = pdat.get_field<Tscal>(pdat.pdl().get_field_idx<Tscal>(
"hpart")).get_buf();
67 auto &buf_rho =
density.get_buf(p.id_patch);
69 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().
get_queue();
// Accessors register their dependencies in depends_list, which is handed to
// submit() so the kernel is ordered after them.
71 sham::EventList depends_list;
73 auto acc_h = buf_h.get_read_access(depends_list);
74 auto acc_rho = buf_rho.get_write_access(depends_list);
76 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
// Particle mass is read from the solver configuration before the kernel so it
// can be captured by value.
77 const Tscal part_mass = solver_config.gpart_mass;
// One work-item per object of the patch (the enclosing launch call is on a line
// missing from the listing).
80 sycl::range<1>{pdat.get_obj_cnt()}, [=](sycl::item<1> item) {
81 u32 gid = (
u32) item.get_id();
82 using namespace shamrock::sph;
83 Tscal rho_ha = rho_h(part_mass, acc_h[gid], Kernel::hfactd);
84 acc_rho[gid] = rho_ha;
// Hand the kernel event back to both buffers.
88 buf_h.complete_event_state(e);
89 buf_rho.complete_event_state(e);
93 }
// "inv_hpart": same structure as the "rho" branch, storing 1.0 / hpart.
else if (field_name ==
"inv_hpart" && std::is_same_v<Tscal, Tfield>) {
95 auto inv_hpart = make_field(1,
"inv_hpart",
"inv_hpart");
97 scheduler().for_each_patchdata_nonempty([&](
const Patch p, PatchDataLayer &pdat) {
98 shamlog_debug_ln(
"sph::vtk",
"compute inv_hpart field for patch ", p.id_patch);
101 = pdat.get_field<Tscal>(pdat.pdl().get_field_idx<Tscal>(
"hpart")).get_buf();
102 auto &buf_inv_hpart = inv_hpart.get_buf(p.id_patch);
104 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().
get_queue();
106 sham::EventList depends_list;
108 auto acc_h = buf_h.get_read_access(depends_list);
109 auto acc_inv_hpart = buf_inv_hpart.get_write_access(depends_list);
111 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
113 sycl::range<1>{pdat.get_obj_cnt()}, [=](sycl::item<1> item) {
114 u32 gid = (
u32) item.get_id();
115 using namespace shamrock::sph;
116 acc_inv_hpart[gid] = 1.0 / acc_h[gid];
120 buf_h.complete_event_state(e);
121 buf_inv_hpart.complete_event_state(e);
125 }
// "unity": fill the buffer of a patch with the constant 1.
// NOTE(review): the loop that provides `id_patch` is not visible in this listing.
else if (field_name ==
"unity" && std::is_same_v<Tscal, Tfield>) {
126 using namespace shamrock;
127 using namespace shamrock::patch;
129 auto unity = make_field(1,
"unity",
"unity");
131 unity.get_buf(id_patch).fill(1);
135 }
// "custom": the values come from the user-supplied callable, invoked per patch.
else if (field_name ==
"custom" && custom_getter.has_value()) {
136 std::function<py::array_t<Tfield>(
size_t, pybind11::dict &)> &field_source_getter
137 = custom_getter.value();
139 auto custom = make_field(1,
"custom",
"custom");
// Timer used for the timing report after the patch loop (the call that starts it
// is not visible in this listing).
141 shambase::Timer timer;
144 scheduler().for_each_patchdata_nonempty([&](
const Patch p, PatchDataLayer &pdat) {
145 shamlog_debug_ln(
"sph::vtk",
"compute custom field for patch ", p.id_patch);
147 auto &buf_custom = custom.get_buf(p.id_patch);
// NOTE(review): `q` is not used by any visible line of this branch.
149 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().
get_queue();
// Convert the patch data to a Python dict and copy the output buffer to a
// std::vector so its size can be compared with the returned array.
151 py::dict dic_out = shamrock::pdat_to_dic(pdat);
152 std::vector<Tfield> acc_custom = buf_custom.copy_to_stdvec();
// Call the user callable with the object count of the patch and the dict.
154 py::array_t<Tfield> custom_array
155 = field_source_getter(pdat.get_obj_cnt(), dic_out);
// Size mismatch between the returned array and the buffer is an error
// (the statement consuming the message is on a line missing from the listing).
157 if (acc_custom.size() != custom_array.size()) {
159 "custom_array size does not match the number of particles");
// Convert the returned array to a std::vector and upload it into the buffer.
162 acc_custom = custom_array.template cast<std::vector<Tfield>>();
164 buf_custom.copy_from_stdvec(acc_custom);
// Reduce the elapsed time with allreduce_max (presumably the maximum over all
// ranks - confirm) and log it; part of the log call is not visible.
169 f64 worse_time_rank = shamalgs::collective::allreduce_max(timer.
elapsed_sec());
173 "sph::RenderFieldGetter",
174 "compute custom field took : ",
// Accessor returning the device buffer of the patch-data field named field_name.
// NOTE(review): no visible line uses this lambda; which scope it belongs to
// cannot be determined from the listing.
183 auto field_source_getter
184 = [&](
const shamrock::patch::Patch cur_p,
185 shamrock::patch::PatchDataLayer &pdat) ->
const sham::DeviceBuffer<Tfield> & {
186 return pdat.get_field<Tfield>(pdat.pdl().get_field_idx<Tfield>(field_name)).get_buf();
// Look up the descriptor and index of field_name in the scheduler's layout
// (pdl_old()); `ifield` is not used by any visible line.
189 FieldDescriptor<Tfield> desc = scheduler().pdl_old().template get_field<Tfield>(field_name);
190 u32 ifield = scheduler().pdl_old().template get_field_idx<Tfield>(field_name);
// Output field named after the descriptor, created with nvar = 1.
196 auto ret = make_field(1, desc.
name, desc.
name);
// For each non-empty patch, take the output buffer and the matching patch-data
// buffer. NOTE(review): the call consuming the second buffer is on a missing line;
// presumably a copy into `buf` - confirm.
198 scheduler().for_each_patchdata_nonempty([&](
const Patch p, PatchDataLayer &pdat) {
199 sham::DeviceBuffer<Tfield> &buf = ret.get_buf(p.id_patch);
201 pdat.get_field<Tfield>(pdat.pdl().
get_field_idx<Tfield>(field_name)).get_buf());
// Builds the requested field with build_field() and passes a per-patch buffer
// accessor for it to `lambda`, returning whatever `lambda` returns.
//
// NOTE(review): the parameter line declaring `lambda` (original line 210) and the
// closing lines of this function are missing from this listing.
//
// @param field_name     Forwarded unchanged to build_field().
// @param custom_getter  Forwarded unchanged to build_field().
// @return The sham::DeviceBuffer<Tfield> produced by lambda(field_source_getter).
207 template<
class Tvec,
class Tfield,
template<
class>
class SPHKernel>
208 auto RenderFieldGetter<Tvec, Tfield, SPHKernel>::runner_function(
209 std::string field_name,
211 std::optional<std::function<py::array_t<Tfield>(
size_t, pybind11::dict &)>> custom_getter)
212 -> sham::DeviceBuffer<Tfield> {
// Materialise the field once; the accessor below only hands out references to it.
214 auto field = build_field(field_name, custom_getter);
// Accessor mapping a patch to the buffer of `field` for that patch id. The
// PatchDataLayer argument is not used by the visible body.
216 auto field_source_getter
217 = [&](
const shamrock::patch::Patch cur_p,
218 shamrock::patch::PatchDataLayer &pdat) ->
const sham::DeviceBuffer<Tfield> & {
219 return field.get_buf(cur_p.
id_patch);
// `field` is a local captured by reference, so the accessor is only valid during
// this call.
222 return lambda(field_source_getter);
constexpr const char * density
Density \rho (derived from h).
constexpr const char * sizes
Temporary sizes for h-iteration.
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
void copy_from(const DeviceBuffer< T, new_target > &other, size_t copy_size)
Copies the content of another buffer to this one.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
f64 elapsed_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
void start()
Starts the timer.
void stop()
Stops the timer and stores the elapsed time in nanoseconds.
std::string name
The name of the field.
u32 nvar
The number of variables of the field per object.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
virtual void ensure_sizes(const shambase::DistributedData< u32 > &sizes)
Ensure that the sizes of the patches in the field match the given sizes (Can resize the underlying fi...
This header file contains utility functions related to exception handling in the code.
ExcptTypes make_except_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Create an exception with a message and a location.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
namespace for math utility
namespace for the sph model modules
u64 id_patch
unique key that identifies the patch