Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
compute_neigh_graph.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
22#include "shammath/AABB.hpp"
26
27namespace shammodels::basegodunov::modules::details {
28
40 template<class NeighFindKernel, class... Args>
42 const sham::DeviceScheduler_ptr &dev_sched, u32 graph_nodes, Args &&...args) {
43
44 auto &q = dev_sched->get_queue();
45
46 NeighFindKernel kergen(std::forward<Args>(args)...);
47
48 // [i] is the number of link for block i in mpdat (last value is 0)
49 sham::DeviceBuffer<u32> link_counts(graph_nodes + 1, dev_sched);
50
51 sham::EventList deps;
52 auto ker = kergen.get_read_access(deps);
53 auto ptr_link_cnt = link_counts.get_write_access(deps);
54
55 // fill buffer with number of link in the block graph
56 auto e = q.submit(deps, [&](sycl::handler &cgh) {
57 shambase::parallel_for(cgh, graph_nodes, "count block graph link", [=](u64 gid) {
58 u32 id_a = (u32) gid;
59 u32 block_found_count = 0;
60
61 ker.for_each_other_index(id_a, [&](u32 id_b) {
62 block_found_count++;
63 });
64
65 ptr_link_cnt[id_a] = block_found_count;
66 });
67 });
68
69 link_counts.complete_event_state(e);
70 kergen.complete_event_state(e);
71
72 // set the last val to 0 so that the last slot after exclusive scan is the sum
73 link_counts.set_val_at_idx(graph_nodes, 0);
74
75 sham::DeviceBuffer<u32> link_cnt_offsets
76 = shamalgs::numeric::scan_exclusive(dev_sched, link_counts, graph_nodes + 1);
77
78 u32 link_cnt = link_cnt_offsets.get_val_at_idx(graph_nodes);
79
80 sham::DeviceBuffer<u32> ids_links(link_cnt, dev_sched);
81
82 sham::EventList deps2;
83 auto cnt_offsets = link_cnt_offsets.get_read_access(deps2);
84 auto ker2 = kergen.get_read_access(deps);
85 auto links = ids_links.get_write_access(deps2);
86
87 // find the neigh ids
88 auto e2 = q.submit(deps2, [&](sycl::handler &cgh) {
89 shambase::parallel_for(cgh, graph_nodes, "get ids block graph link", [=](u64 gid) {
90 u32 id_a = (u32) gid;
91
92 u32 next_link_idx = cnt_offsets[id_a];
93
94 ker2.for_each_other_index(id_a, [&](u32 id_b) {
95 links[next_link_idx] = id_b;
96 next_link_idx++;
97 });
98 });
99 });
100
101 link_cnt_offsets.complete_event_state(e2);
102 ids_links.complete_event_state(e2);
103 kergen.complete_event_state(e2);
104
106 return Graph(
107 Graph{std::move(link_cnt_offsets), std::move(ids_links), link_cnt, graph_nodes});
108 };
109
123 template<class NeighFindKernel, class... Args>
125 const sham::DeviceScheduler_ptr &dev_sched, u32 graph_nodes, Args &&...args) {
126
127 auto &q = dev_sched->get_queue();
128
129 // [i] is the number of link for block i in mpdat (last value is 0)
130 sham::DeviceBuffer<u32> link_counts(graph_nodes + 1, dev_sched);
131
132 sham::EventList deps;
133 auto ptr_link_cnt = link_counts.get_write_access(deps);
134
135 // fill buffer with number of link in the block graph
136 auto e = q.submit(deps, [&](sycl::handler &cgh) {
137 NeighFindKernel ker(cgh, std::forward<Args>(args)...);
138 shambase::parallel_for(cgh, graph_nodes, "count block graph link", [=](u64 gid) {
139 u32 id_a = (u32) gid;
140 u32 block_found_count = 0;
141
142 ker.for_each_other_index(id_a, [&](u32 id_b) {
143 block_found_count++;
144 });
145
146 ptr_link_cnt[id_a] = block_found_count;
147 });
148 });
149
150 link_counts.complete_event_state(e);
151
152 // set the last val to 0 so that the last slot after exclusive scan is the sum
153 link_counts.set_val_at_idx(graph_nodes, 0);
154
155 sham::DeviceBuffer<u32> link_cnt_offsets
156 = shamalgs::numeric::scan_exclusive(dev_sched, link_counts, graph_nodes + 1);
157
158 u32 link_cnt = link_cnt_offsets.get_val_at_idx(graph_nodes);
159
160 sham::DeviceBuffer<u32> ids_links(link_cnt, dev_sched);
161
162 sham::EventList deps2;
163 auto cnt_offsets = link_cnt_offsets.get_read_access(deps2);
164 auto links = ids_links.get_write_access(deps2);
165
166 // find the neigh ids
167 auto e2 = q.submit(deps2, [&](sycl::handler &cgh) {
168 NeighFindKernel ker(cgh, std::forward<Args>(args)...);
169 shambase::parallel_for(cgh, graph_nodes, "get ids block graph link", [=](u64 gid) {
170 u32 id_a = (u32) gid;
171
172 u32 next_link_idx = cnt_offsets[id_a];
173
174 ker.for_each_other_index(id_a, [&](u32 id_b) {
175 links[next_link_idx] = id_b;
176 next_link_idx++;
177 });
178 });
179 });
180
181 link_cnt_offsets.complete_event_state(e2);
182 ids_links.complete_event_state(e2);
183
185 return Graph(
186 Graph{std::move(link_cnt_offsets), std::move(ids_links), link_cnt, graph_nodes});
187 };
188
189} // namespace shammodels::basegodunov::modules::details
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
void complete_event_state(sycl::event e) const
Complete the event state of the buffer.
T * get_write_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{})
Get a read-write pointer to the buffer's data.
T get_val_at_idx(size_t idx) const
Get the value at a given index in the buffer.
const T * get_read_access(sham::EventList &depends_list, SourceLocation src_loc=SourceLocation{}) const
Get a read-only pointer to the buffer's data.
Class to manage a list of SYCL events.
Definition EventList.hpp:31
shammodels::basegodunov::modules::NeighGraph compute_neigh_graph_deprecated(const sham::DeviceScheduler_ptr &dev_sched, u32 graph_nodes, Args &&...args)
Create a neighbour graph using a class that will list the ids of the found neighbours NeighFindKerne...
shammodels::basegodunov::modules::NeighGraph compute_neigh_graph(const sham::DeviceScheduler_ptr &dev_sched, u32 graph_nodes, Args &&...args)
Create a neighbour graph using a class that will list the ids of the found neighbours NeighFindKerne...
sycl::buffer< T > scan_exclusive(sycl::queue &q, sycl::buffer< T > &buf1, u32 len)
Computes the exclusive sum of elements in a SYCL buffer.
Definition numeric.cpp:35