Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
AnalysisTotalMomentum.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
19#include "shambase/memory.hpp"
26
28
/**
 * @brief SPH analysis module that computes the total linear momentum of the model.
 *
 * Keeps non-owning references to the model's solver and Shamrock context, which are
 * bound in the constructor from `model.solver` and `model.ctx`.
 *
 * @tparam Tvec      Vector type used for particle velocities / momentum.
 * @tparam SPHKernel SPH kernel template the solver is instantiated with.
 */
29 template<class Tvec, template<class> class SPHKernel>
31 public:
// Component type of Tvec (presumably the scalar floating-point type — confirm).
32 using Tscal = shambase::VecComponent<Tvec>;
34
35 using Solver = Solver<Tvec, SPHKernel>;
36
// Reference to the model's solver; bound to `model.solver` in the constructor.
38 Solver &solver;
// Reference to the Shamrock context; bound to `model.ctx` in the constructor.
39 ShamrockCtx &ctx;
40
// NOTE(review): the initializer list names `ctx` before `solver`, but members are
// initialized in declaration order (`solver` is declared before `ctx` above).
// Both are read from `model` only, so the order is harmless; listing them in
// declaration order would silence -Wreorder.
42 : model(model), ctx(model.ctx), solver(model.solver) {};
/**
 * @brief Compute the total linear momentum of the system.
 *
 * For every non-empty local patch, computes pmass * vxyz[i] per particle on the
 * device and sums it into a rank-local total. It then combines the rank-local totals
 * with shamalgs::collective::allreduce_sum. Finally, when the solver storage holds
 * sinks, it adds sink.mass * sink.velocity for each of them.
 *
 * @return Total momentum (particles + sinks) as a Tvec.
 */
44 auto get_total_momentum() -> Tvec {
45 PatchScheduler &sched = shambase::get_check_ref(ctx.sched);
46 auto dev_sched_ptr = shamsys::instance::get_compute_scheduler_ptr();
47 sham::DeviceQueue &q = shambase::get_check_ref(dev_sched_ptr).get_queue();
48
// Index of the "vxyz" (velocity) field in the patch data layout.
49 const u32 ivxyz = sched.pdl_old().template get_field_idx<Tvec>("vxyz");
// One mass value applied to every particle (gpart_mass from the solver config).
50 const Tscal pmass = solver.solver_config.gpart_mass;
51
// Rank-local accumulator, summed over all local patches.
52 Tvec total_momentum = {};
53
// Device scratch buffer holding per-particle momentum; reused across patches.
54 sham::DeviceBuffer<Tvec> total_momentum_part(0, dev_sched_ptr);
55
56 sched.for_each_patchdata_nonempty([&](const shamrock::patch::Patch p,
58 u32 len = pdat.get_obj_cnt();
59
// NOTE(review): resize() defaults to keep_data=true, but every entry is
// overwritten by the kernel below. resize(len, false) would avoid copying the
// previous patch's data when the buffer grows — confirm before changing.
60 total_momentum_part.resize(len);
61
// Device kernel: total_momentum_part[i] = pmass * vxyz[i] for i in [0, len).
62 sham::DeviceBuffer<Tvec> &vxyz_buf = pdat.get_field_buf_ref<Tvec>(ivxyz);
63
65 q,
66 sham::MultiRef{vxyz_buf},
67 sham::MultiRef{total_momentum_part},
68 len,
69 [pmass](
70 u32 i, const Tvec *__restrict vxyz, Tvec *__restrict total_momentum_part) {
71 total_momentum_part[i] = pmass * vxyz[i];
72 });
73
// Add this patch's summed contribution (entries 0..len) to the rank-local total.
74 total_momentum
75 += shamalgs::primitives::sum(dev_sched_ptr, total_momentum_part, 0, len);
76 });
77
// Combine rank-local totals across ranks (collective allreduce).
78 Tvec tot_total_momentum = shamalgs::collective::allreduce_sum(total_momentum);
79
// NOTE(review): sink momentum is added after the cross-rank reduction. This is only
// correct if every rank holds the same full sink list (presumably true) — confirm.
80 if (!solver.storage.sinks.is_empty()) {
81 for (auto &sink : solver.storage.sinks.get()) {
82 tot_total_momentum += sink.mass * sink.velocity;
83 }
84 }
85
86 return tot_total_momentum;
87 }
88 };
89} // namespace shammodels::sph::modules
MPI scheduler.
std::uint32_t u32
32 bit unsigned integer
The MPI scheduler.
A buffer allocated in USM (Unified Shared Memory)
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
A SYCL queue associated with a device and a context.
The shamrock SPH model.
Definition Model.hpp:55
PatchDataLayer container class, the layout is described in patchdata_layout.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
T sum(const sham::DeviceScheduler_ptr &sched, const sham::DeviceBuffer< T > &buf1, u32 start_id, u32 end_id)
Compute the sum of elements in a device buffer within a specified range.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_erro...
Definition memory.hpp:110
namespace for the sph model modules
A class that references multiple buffers or similar objects.
Patch object that contain generic patch information.
Definition Patch.hpp:33