32 using Tscal = shambase::VecComponent<Tvec>;
35 using Solver = Solver<Tvec, SPHKernel>;
42 : model(model), ctx(model.ctx), solver(model.solver) {};
44 auto get_total_momentum() -> Tvec {
46 auto dev_sched_ptr = shamsys::instance::get_compute_scheduler_ptr();
49 const u32 ivxyz = sched.pdl_old().template get_field_idx<Tvec>(
"vxyz");
50 const Tscal pmass = solver.solver_config.gpart_mass;
52 Tvec total_momentum = {};
58 u32 len = pdat.get_obj_cnt();
60 total_momentum_part.
resize(len);
70 u32 i,
const Tvec *__restrict vxyz, Tvec *__restrict total_momentum_part) {
71 total_momentum_part[i] = pmass * vxyz[i];
78 Tvec tot_total_momentum = shamalgs::collective::allreduce_sum(total_momentum);
80 if (!solver.storage.sinks.is_empty()) {
81 for (
auto &sink : solver.storage.sinks.get()) {
82 tot_total_momentum += sink.mass * sink.velocity;
86 return tot_total_momentum;
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...