// Scalar type derived from the vector type Tvec through shambase::VecComponent.
36 using Tscal = shambase::VecComponent<Tvec>;
// Non-owning multidimensional views over mutable Tscal data. Extents are
// all dynamic (set at construction) and use u32 as the index type.
// Rank 1 is used for per-bin arrays, rank 3 for three-index tables.
38 using mdspan_rank_1 = std::mdspan<Tscal, std::dextents<u32, 1>>;
39 using mdspan_rank_3 = std::mdspan<Tscal, std::dextents<u32, 3>>;
// Read-only counterparts of the views above (element type is const Tscal).
41 using const_mdspan_rank_1 = std::mdspan<const Tscal, std::dextents<u32, 1>>;
42 using const_mdspan_rank_3 = std::mdspan<const Tscal, std::dextents<u32, 3>>;
54 const Tscal *__restrict massgrid_ptr,
55 const Tscal *__restrict tensor_tabflux_coag,
57 const Tscal *__restrict s_j,
58 const Tvec *__restrict delta_v_j,
59 Tscal *__restrict S_coag)
const {
61 auto range = sycl::nd_range<1>{corrected_len, group_size};
63 auto local_acc_sz_nbins = sycl::range<1>{group_size * nbins};
65 auto true_size = this->true_size;
66 auto rho_eps = this->rho_eps;
67 auto dv_max = this->dv_max;
69 return [=, nbins = this->nbins](sycl::handler &cgh) {
70 auto gij_acc = sycl::local_accessor<Tscal>{local_acc_sz_nbins, cgh};
71 auto flux_acc = sycl::local_accessor<Tscal>{local_acc_sz_nbins, cgh};
73 cgh.parallel_for(range, [=](sycl::nd_item<1> tid) {
74 const u64 id_a = tid.get_global_linear_id();
75 const u64 lid = tid.get_local_linear_id();
77 if (id_a >= true_size) {
81 u32 id_a_d = id_a * nbins;
84 const_mdspan_rank_3 tabflux_coag(tensor_tabflux_coag, nbins, nbins, nbins);
85 const_mdspan_rank_1 massgrid(massgrid_ptr, nbins + 1);
88 auto gij_loc = &(gij_acc[nbins * lid]);
89 auto flux_loc = &(flux_acc[nbins * lid]);
91 mdspan_rank_1 gij(gij_loc, nbins);
92 mdspan_rank_1 flux(flux_loc, nbins);
95 mdspan_rank_1 S_coag_span(S_coag + id_a_d, nbins);
98 auto rho_dust = [&](
int j) {
99 auto tmp = s_j[id_a_d + j];
103 auto dv = [&, delta_v = delta_v_j + id_a_d](
int i,
int j) {
105 auto tmp = sycl::length(delta_v[j] - delta_v[i]);
106 return (tmp > dv_max) ? 0 : tmp;
112 shamphys::coala_k0_source_term(
132 auto edges = get_edges();
134 auto s_j_spans = edges.s_j.get_spans();
135 auto delta_v_j_spans = edges.delta_v_j.get_spans();
137 auto counts = edges.part_counts.indexes;
139 edges.S_coag.ensure_sizes(counts);
140 auto S_coag_spans = edges.S_coag.get_spans();
142 Tscal rho_eps = edges.rhodust_eps.value;
143 Tscal dv_max = edges.dv_max.value;
144 const std::vector<Tscal> &massgrid = edges.massgrid.value;
145 const std::vector<Tscal> &tensor_tabflux_coag = edges.tensor_tabflux_coag.value;
147 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
158 counts.for_each([&](
u64 id_patch,
u64 count) {
160 u32 corrected_len = group_cnt * group_size;
162 sham::kernel_call_hndl(
166 tensor_tabflux_coag_buf,
167 s_j_spans.get(id_patch),
168 delta_v_j_spans.get(id_patch)},
175 .corrected_len = corrected_len,
176 .group_size = group_size,
177 .true_size =
u32(count)});