Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
compute_neigh_graph.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
18
22#include "shammath/AABB.hpp"
26
27namespace shammodels::basegodunov::modules::details {
28
40 template<class NeighFindKernel, class... Args>
42 const sham::DeviceScheduler_ptr &dev_sched, u32 graph_nodes, Args &&...args) {
43
44 auto &q = dev_sched->get_queue();
45
46 NeighFindKernel kergen(std::forward<Args>(args)...);
47
48 // [i] is the number of link for block i in mpdat (last value is 0)
49 sham::DeviceBuffer<u32> link_counts(graph_nodes + 1, dev_sched);
50
51 sham::EventList deps;
52 auto ker = kergen.get_read_access(deps);
53 auto ptr_link_cnt = link_counts.get_write_access(deps);
54
55 // fill buffer with number of link in the block graph
56 auto e = q.submit(deps, [&](sycl::handler &cgh) {
57 shambase::parallel_for(cgh, graph_nodes, "count block graph link", [=](u64 gid) {
58 u32 id_a = (u32) gid;
59 u32 block_found_count = 0;
60
61 ker.for_each_other_index(id_a, [&](u32 id_b) {
62 block_found_count++;
63 });
64
65 ptr_link_cnt[id_a] = block_found_count;
66 });
67 });
68
69 link_counts.complete_event_state(e);
70 kergen.complete_event_state(e);
71
72 // set the last val to 0 so that the last slot after exclusive scan is the sum
73 link_counts.set_val_at_idx(graph_nodes, 0);
74
75 sham::DeviceBuffer<u32> link_cnt_offsets
76 = shamalgs::numeric::scan_exclusive(dev_sched, link_counts, graph_nodes + 1);
77
78 u32 link_cnt = link_cnt_offsets.get_val_at_idx(graph_nodes);
79
80 sham::DeviceBuffer<u32> ids_links(link_cnt, dev_sched);
81
82 sham::EventList deps2;
83 auto cnt_offsets = link_cnt_offsets.get_read_access(deps2);
84 auto ker2 = kergen.get_read_access(deps);
85 auto links = ids_links.get_write_access(deps2);
86
87 // find the neigh ids
88 auto e2 = q.submit(deps2, [&](sycl::handler &cgh) {
89 shambase::parallel_for(cgh, graph_nodes, "get ids block graph link", [=](u64 gid) {
90 u32 id_a = (u32) gid;
91
92 u32 next_link_idx = cnt_offsets[id_a];
93
94 ker2.for_each_other_index(id_a, [&](u32 id_b) {
95 links[next_link_idx] = id_b;
96 next_link_idx++;
97 });
98 });
99 });
100
101 link_cnt_offsets.complete_event_state(e2);
102 ids_links.complete_event_state(e2);
103 kergen.complete_event_state(e2);
104
106 return Graph(
107 Graph{
108 .node_link_offset = std::move(link_cnt_offsets),
109 .node_links = std::move(ids_links),
110 .link_count = link_cnt,
111 .obj_cnt = graph_nodes});
112 };
113
127 template<class NeighFindKernel, class... Args>
129 const sham::DeviceScheduler_ptr &dev_sched, u32 graph_nodes, Args &&...args) {
130
131 auto &q = dev_sched->get_queue();
132
133 // [i] is the number of link for block i in mpdat (last value is 0)
134 sham::DeviceBuffer<u32> link_counts(graph_nodes + 1, dev_sched);
135
136 sham::EventList deps;
137 auto ptr_link_cnt = link_counts.get_write_access(deps);
138
139 // fill buffer with number of link in the block graph
140 auto e = q.submit(deps, [&](sycl::handler &cgh) {
141 NeighFindKernel ker(cgh, std::forward<Args>(args)...);
142 shambase::parallel_for(cgh, graph_nodes, "count block graph link", [=](u64 gid) {
143 u32 id_a = (u32) gid;
144 u32 block_found_count = 0;
145
146 ker.for_each_other_index(id_a, [&](u32 id_b) {
147 block_found_count++;
148 });
149
150 ptr_link_cnt[id_a] = block_found_count;
151 });
152 });
153
154 link_counts.complete_event_state(e);
155
156 // set the last val to 0 so that the last slot after exclusive scan is the sum
157 link_counts.set_val_at_idx(graph_nodes, 0);
158
159 sham::DeviceBuffer<u32> link_cnt_offsets
160 = shamalgs::numeric::scan_exclusive(dev_sched, link_counts, graph_nodes + 1);
161
162 u32 link_cnt = link_cnt_offsets.get_val_at_idx(graph_nodes);
163
164 sham::DeviceBuffer<u32> ids_links(link_cnt, dev_sched);
165
166 sham::EventList deps2;
167 auto cnt_offsets = link_cnt_offsets.get_read_access(deps2);
168 auto links = ids_links.get_write_access(deps2);
169
170 // find the neigh ids
171 auto e2 = q.submit(deps2, [&](sycl::handler &cgh) {
172 NeighFindKernel ker(cgh, std::forward<Args>(args)...);
173 shambase::parallel_for(cgh, graph_nodes, "get ids block graph link", [=](u64 gid) {
174 u32 id_a = (u32) gid;
175
176 u32 next_link_idx = cnt_offsets[id_a];
177
178 ker.for_each_other_index(id_a, [&](u32 id_b) {
179 links[next_link_idx] = id_b;
180 next_link_idx++;
181 });
182 });
183 });
184
185 link_cnt_offsets.complete_event_state(e2);
186 ids_links.complete_event_state(e2);
187
189 return Graph(
190 Graph{
191 .node_link_offset = std::move(link_cnt_offsets),
192 .node_links = std::move(ids_links),
193 .link_count = link_cnt,
194 .obj_cnt = graph_nodes});
195 };
196
197} // namespace shammodels::basegodunov::modules::details
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
A buffer allocated in USM (Unified Shared Memory).
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
T get_val_at_idx(size_t idx) const
Get the value at a given index in the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
shammodels::basegodunov::modules::NeighGraph compute_neigh_graph_deprecated(const sham::DeviceScheduler_ptr &dev_sched, u32 graph_nodes, Args &&...args)
Create a neighbour graph using a class that will list the ids of the found neighbours. NeighFindKerne...
shammodels::basegodunov::modules::NeighGraph compute_neigh_graph(const sham::DeviceScheduler_ptr &dev_sched, u32 graph_nodes, Args &&...args)
Create a neighbour graph using a class that will list the ids of the found neighbours. NeighFindKerne...
sycl::buffer< T > scan_exclusive(sycl::queue &q, sycl::buffer< T > &buf1, u32 len)
Computes the exclusive sum of elements in a SYCL buffer.
Definition numeric.cpp:35