// build_field: materialises one output field for rendering. A few names are
// derived on the fly ("rho", "inv_hpart", "unity", "custom"); otherwise the
// field named `field_name` is looked up in the patch data layout.
// NOTE(review): this listing is a documentation-page extraction; the signature
// line (original line 32) and several body lines are not visible here, so the
// return type and the enclosing per-patch loops are not documented.
31 template<
class Tvec,
class Tfield,
template<
class>
class SPHKernel>
33 std::string field_name,
34 std::optional<std::function<py::array_t<Tfield>(
size_t, pybind11::dict &)>> custom_getter) {
// A custom_getter is only meaningful for the "custom" field; this check rejects
// it for any other name (the throwing call, original line 37, is not visible).
36 if (field_name !=
"custom" && custom_getter.has_value()) {
38 "custom_getter is only supported for the custom field");
44 using namespace shamrock::patch;
// Records each patch's object count in `sizes` (presumably used to size the
// output field — confirm against the missing lines).
47 sizes.add_obj(p.id_patch, pdat.get_obj_cnt());
// Helper creating an output field from a component count, a name and a TeX
// symbol (its body is not visible here).
50 auto make_field = [&](
u32 nvar, std::string name, std::string texsymbol) {
57 if constexpr (std::is_same_v<Tfield, f64>) {
// The derived-field branches below are compiled only when Tfield is f64, and
// each also requires Tscal == Tfield.
58 if (field_name ==
"rho" && std::is_same_v<Tscal, Tfield>) {
// "rho": per-particle density computed from the "hpart" field via
// rho_h(part_mass, h, Kernel::hfactd), with part_mass = solver_config.gpart_mass.
60 auto density = make_field(1,
"rho",
"rho");
63 shamlog_debug_ln(
"sph::vtk",
"compute rho field for patch ", p.id_patch);
66 = pdat.get_field<Tscal>(pdat.pdl().
get_field_idx<Tscal>(
"hpart")).get_buf();
67 auto &buf_rho =
density.get_buf(p.id_patch);
73 auto acc_h = buf_h.get_read_access(depends_list);
74 auto acc_rho = buf_rho.get_write_access(depends_list);
76 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
77 const Tscal part_mass = solver_config.gpart_mass;
// One work-item per object of the patch (range = pdat.get_obj_cnt()).
80 sycl::range<1>{pdat.get_obj_cnt()}, [=](sycl::item<1> item) {
81 u32 gid = (
u32) item.get_id();
82 using namespace shamrock::sph;
83 Tscal rho_ha = rho_h(part_mass, acc_h[gid], Kernel::hfactd);
84 acc_rho[gid] = rho_ha;
// Attach the kernel event `e` to both buffers whose accessors were taken above.
88 buf_h.complete_event_state(e);
89 buf_rho.complete_event_state(e);
93 }
else if (field_name ==
"inv_hpart" && std::is_same_v<Tscal, Tfield>) {
// "inv_hpart": per-particle 1 / h, read from the "hpart" field.
95 auto inv_hpart = make_field(1,
"inv_hpart",
"inv_hpart");
98 shamlog_debug_ln(
"sph::vtk",
"compute inv_hpart field for patch ", p.id_patch);
101 = pdat.get_field<Tscal>(pdat.pdl().
get_field_idx<Tscal>(
"hpart")).get_buf();
102 auto &buf_inv_hpart = inv_hpart.get_buf(p.id_patch);
108 auto acc_h = buf_h.get_read_access(depends_list);
109 auto acc_inv_hpart = buf_inv_hpart.get_write_access(depends_list);
111 auto e = q.
submit(depends_list, [&](sycl::handler &cgh) {
113 sycl::range<1>{pdat.get_obj_cnt()}, [=](sycl::item<1> item) {
114 u32 gid = (
u32) item.get_id();
115 using namespace shamrock::sph;
// NOTE(review): no guard against h == 0 here; assumes hpart is strictly positive.
116 acc_inv_hpart[gid] = 1.0 / acc_h[gid];
120 buf_h.complete_event_state(e);
121 buf_inv_hpart.complete_event_state(e);
125 }
else if (field_name ==
"unity" && std::is_same_v<Tscal, Tfield>) {
// "unity": a field whose buffer is filled with 1.
// NOTE(review): this indexes with `id_patch`, whereas the other branches use
// `p.id_patch`; the enclosing per-patch lambda (original line 130) is not
// visible — confirm `id_patch` is the intended name in scope.
127 using namespace shamrock::patch;
129 auto unity = make_field(1,
"unity",
"unity");
131 unity.get_buf(id_patch).fill(1);
135 }
else if (field_name ==
"custom" && custom_getter.has_value()) {
// "custom": values come from the Python callback custom_getter, which receives
// the patch object count and a dict built from the patch data (pdat_to_dic).
136 std::function<py::array_t<Tfield>(
size_t, pybind11::dict &)> &field_source_getter
137 = custom_getter.value();
139 auto custom = make_field(1,
"custom",
"custom");
145 shamlog_debug_ln(
"sph::vtk",
"compute custom field for patch ", p.id_patch);
147 auto &buf_custom = custom.get_buf(p.id_patch);
151 py::dict dic_out = shamrock::pdat_to_dic(pdat);
// NOTE(review): the buffer is copied to a std::vector only so that its size can
// be compared below; the vector is reassigned from custom_array before being
// written back, so the copied values themselves are never used.
152 std::vector<Tfield> acc_custom = buf_custom.copy_to_stdvec();
154 py::array_t<Tfield> custom_array
155 = field_source_getter(pdat.get_obj_cnt(), dic_out);
// Reject callback results whose element count differs from the field buffer
// (the throwing call, original line 158, is not visible).
157 if (acc_custom.size() != custom_array.size()) {
159 "custom_array size does not match the number of particles");
162 acc_custom = custom_array.template cast<std::vector<Tfield>>();
164 buf_custom.copy_from_stdvec(acc_custom);
// Timing report: presumably the maximum elapsed time across ranks
// (allreduce_max), logged under "sph::RenderFieldGetter".
169 f64 worse_time_rank = shamalgs::collective::allreduce_max(timer.
elasped_sec());
173 "sph::RenderFieldGetter",
174 "compute custom field took : ",
// Default path: read the field named `field_name` directly from each patch's
// data layout.
183 auto field_source_getter
186 return pdat.get_field<Tfield>(pdat.pdl().
get_field_idx<Tfield>(field_name)).get_buf();
190 u32 ifield = scheduler().pdl_old().template get_field_idx<Tfield>(field_name);
// NOTE(review): `desc` is declared on a line not visible here. nvar is
// hard-coded to 1; if the layout field has desc.nvar > 1 the output field would
// be undersized — confirm.
196 auto ret = make_field(1, desc.
name, desc.
name);
201 pdat.get_field<Tfield>(pdat.pdl().
get_field_idx<Tfield>(field_name)).get_buf());
// runner_function: builds the requested field once (via build_field, forwarding
// field_name and custom_getter), then passes `lambda` a getter that returns that
// field's buffer for a given patch (cur_p.id_patch).
// NOTE(review): the declaration of `lambda` (original line 210) and the getter's
// full signature are not visible in this listing.
207 template<
class Tvec,
class Tfield,
template<
class>
class SPHKernel>
208 auto RenderFieldGetter<Tvec, Tfield, SPHKernel>::runner_function(
209 std::string field_name,
211 std::optional<std::function<py::array_t<Tfield>(
size_t, pybind11::dict &)>> custom_getter)
214 auto field = build_field(field_name, custom_getter);
216 auto field_source_getter
219 return field.get_buf(cur_p.
id_patch);
222 return lambda(field_source_getter);
constexpr const char * density
Density \rho (derived from h)
constexpr const char * sizes
Temporary sizes for h-iteration.
double f64
Alias for double.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void copy_from(const DeviceBuffer< T, new_target > &other, size_t copy_size)
Copies the content of another buffer to this one.
A SYCL queue associated with a device and a context.
sycl::event submit(Fct &&fct)
Submits a kernel to the SYCL queue.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
Class to manage a list of SYCL events.
Represents a collection of objects distributed across patches identified by a u64 id.
Class Timer measures the time elapsed since the timer was started.
void end()
Stops the timer and stores the elapsed time in nanoseconds.
f64 elasped_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
void start()
Starts the timer.
Structure describing a field in a patch data layout.
std::string name
The name of the field.
u32 nvar
The number of variables of the field per object.
u32 get_field_idx(const std::string &field_name) const
Get the field id if matching name & type.
PatchDataLayer container class, the layout is described in patchdata_layout.
virtual void ensure_sizes(const shambase::DistributedData< u32 > &sizes)
Ensure that the sizes of the patches in the field match the given sizes (Can resize the underlying fi...
This header file contains utility functions related to exception handling in the code.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
void throw_unimplemented(SourceLocation loc=SourceLocation{})
Throw a std::runtime_error saying that the function is unimplemented.
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
namespace for math utility
namespace for the sph model modules
namespace for the main framework
Patch object that contains generic patch information.
u64 id_patch
Unique key that identifies the patch.