// Type aliases and kernel constants used by the code below.
// Tscal is the type produced by shambase::VecComponent for Tvec; Kernel is the
// SPH kernel instantiated on that type.
26 using Tscal = shambase::VecComponent<Tvec>;
27 using Kernel = SPHKernel<Tscal>;
// Constants forwarded from the kernel type so they can be named without qualification.
28 static constexpr Tscal hfactd = Kernel::hfactd;
29 static constexpr Tscal Rkern = Kernel::Rkern;
// Rkern squared, so that it can be compared against squared distances without a sqrt.
30 static constexpr Tscal Rker2 = Rkern * Rkern;
// Per-thread kernel body: evaluates ds_j_dt for one (particle, dust index) pair
// and writes it to ds_j_dt[thread_id].
//
// Indexing convention visible below: thread_id = id_a * ndust + jdust. The
// per-particle arrays (xyz, hpart, vxyz, omega, pressure) are read at the particle
// index, while s_j, Ttilde_sj and ds_j_dt are read/written at the flattened
// (particle, dust) index.
//
// NOTE(review): some parameters (e.g. thread_id, particle_looper) and the
// declarations of term1, term2 and rab_inv_sat are not in the lines visible here.
35 inline void operator()(
38 const Tvec *__restrict xyz,
39 const Tscal *__restrict hpart,
40 const Tvec *__restrict vxyz,
41 const Tscal *__restrict omega,
42 const Tscal *__restrict pressure,
43 const Tscal *__restrict s_j,
44 const Tscal *__restrict Ttilde_sj,
47 Tscal *__restrict ds_j_dt)
const {
// Split the flat thread index into a particle index and a dust index.
49 u32 id_a = thread_id / ndust;
50 u32 jdust = thread_id % ndust;
// Load the state of particle a.
52 Tscal h_a = hpart[id_a];
53 Tvec xyz_a = xyz[id_a];
54 Tvec vxyz_a = vxyz[id_a];
55 Tscal P_a = pressure[id_a];
56 Tscal omega_a = omega[id_a];
// Dust quantities of particle a for dust index jdust (flattened index == thread_id).
57 Tscal s_j_a = s_j[thread_id];
58 Tscal Ttilde_sj_a = Ttilde_sj[thread_id];
60 using namespace shamrock::sph;
// Density of particle a obtained from rho_h using the particle mass and h_a.
61 Tscal rho_a = rho_h(pmass, h_a, Kernel::hfactd);
// NOTE(review): rho_a_sq, rho_a_inv and omega_a_rho_a_inv are not referenced in
// the lines visible here - confirm whether they are still needed.
62 Tscal rho_a_sq = rho_a * rho_a;
63 Tscal rho_a_inv = 1. / rho_a;
64 Tscal omega_a_rho_a_inv = 1 / (omega_a * rho_a);
// Visit every object id_b that particle_looper reports for id_a.
70 particle_looper.for_each_object(id_a, [&](
u32 id_b) {
71 Tvec dr = xyz_a - xyz[id_b];
72 Tscal rab2 = sycl::dot(dr, dr);
73 Tscal h_b = hpart[id_b];
// True when the squared separation exceeds both (Rkern * h_a)^2 and (Rkern * h_b)^2.
// The statements guarded by this test are not visible here; presumably the pair
// is skipped - TODO confirm.
75 if (rab2 > h_a * h_a * Rker2 && rab2 > h_b * h_b * Rker2) {
// Load the state of particle b; dust arrays use the same flattened indexing as for a.
79 Tvec vxyz_b = vxyz[id_b];
80 Tscal P_b = pressure[id_b];
81 Tscal omega_b = omega[id_b];
82 Tscal s_j_b = s_j[id_b * ndust + jdust];
83 Tscal Ttilde_sj_b = Ttilde_sj[id_b * ndust + jdust];
85 Tscal rab = sycl::sqrt(rab2);
88 Tscal rho_b = rho_h(pmass, h_b, Kernel::hfactd);
// Kernel::dW_3d evaluated at the pair distance with each particle's smoothing length.
90 Tscal Fab_a = Kernel::dW_3d(rab, h_a);
91 Tscal Fab_b = Kernel::dW_3d(rab, h_b);
93 Tvec v_ab = vxyz_a - vxyz_b;
// rab_inv_sat is defined outside the visible lines; presumably a guarded 1 / rab - TODO confirm.
95 Tvec r_ab_unit = dr * rab_inv_sat;
// Arithmetic mean of the two dW_3d values.
97 Tscal F_ab_bar = (Fab_a + Fab_b) / 2;
98 Tscal delta_P = P_a - P_b;
99 Tscal Ts_weighted = (Ttilde_sj_a / rho_a + Ttilde_sj_b / rho_b);
// term1 accumulates the pressure-difference contribution of this pair.
101 term1 += (pmass * s_j_b / rho_b) * Ts_weighted * delta_P * F_ab_bar * rab_inv_sat;
// term2 accumulates pmass * (v_ab . r_ab_unit) * Fab_a over the visited pairs.
102 term2 += pmass * sham::dot(v_ab, r_ab_unit * Fab_a);
// Combine both accumulated sums into the time derivative for this (particle, dust) pair.
106 Tscal ds_j_dt_a = Tscal{-0.5} * term1 + (s_j_a / (2 * rho_a * omega_a)) * term2;
// If s_j_a is already negative and the derivative is negative too, zero the derivative.
109 ds_j_dt_a *= (s_j_a < 0 && ds_j_dt_a < 0) ? 0 : 1;
// When s_j_a is negative, add -s_j_a / (10 * Ttilde_sj_a) to the derivative.
112 ds_j_dt_a += (s_j_a < 0) ? -s_j_a / (10 * Ttilde_sj_a) : 0;
// Store the result at the flattened (particle, dust) index.
114 ds_j_dt[thread_id] = ds_j_dt_a;
// Host-side body: validates field sizes, sizes the output field and launches ComputeKernel.
// NOTE(review): the enclosing function header and the name of the call that receives
// the arguments at the end are not in the lines visible here.
124 auto edges = get_edges();
// Two sets of counts: one including ghosts, one without.
126 auto &part_counts_with_ghost = edges.part_counts_with_ghost.indexes;
127 auto &part_counts = edges.part_counts.indexes;
// Every input field is checked against the counts that include ghosts.
130 edges.xyz.check_sizes(part_counts_with_ghost);
131 edges.hpart.check_sizes(part_counts_with_ghost);
132 edges.vxyz.check_sizes(part_counts_with_ghost);
133 edges.omega.check_sizes(part_counts_with_ghost);
134 edges.pressure.check_sizes(part_counts_with_ghost);
135 edges.s_j.check_sizes(part_counts_with_ghost);
136 edges.Ttilde_sj.check_sizes(part_counts_with_ghost);
// The output field is sized from the counts without ghosts.
139 edges.ds_j_dt.ensure_sizes(part_counts);
141 const Tscal pmass = edges.gpart_mass.value;
// Map each entry of part_counts to count * ndust, matching the kernel's
// thread_id = particle * ndust + dust indexing.
145 auto total_specie_count = part_counts.template map<u32>([&](
u64 id,
u32 count) {
146 return count * ndust;
// Launch arguments: scheduler, input field spans, then the functor built from pmass and ndust.
// Presumably passed to a distributed kernel-call helper - TODO confirm against the full source.
151 shamsys::instance::get_compute_scheduler_ptr(),
153 edges.xyz.get_spans(),
154 edges.hpart.get_spans(),
155 edges.vxyz.get_spans(),
156 edges.omega.get_spans(),
157 edges.pressure.get_spans(),
158 edges.s_j.get_spans(),
159 edges.Ttilde_sj.get_spans(),
163 ComputeKernel{pmass, ndust});
void distributed_data_kernel_call(sham::DeviceScheduler_ptr dev_sched, RefIn in, RefOut in_out, const shambase::DistributedData< index_t > &thread_counts, Functor &&func)
A variant of sham::kernel_call for distributed data.