Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
KarrasRadixTreeField.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
22#include "shambackends/math.hpp"
27#include <functional>
28#include <utility>
29
// Forward declarations of the two tree-field containers defined further down in this file.
30namespace shamtree {
31
// A data structure representing a Karras Radix Tree Field.
40 template<class T>
41 class KarrasRadixTreeField;
42
// A data structure representing a field with multiple variables per cell for a Karras Radix Tree.
52 template<class T>
53 class KarrasRadixTreeFieldMultiVar;
54} // namespace shamtree
55
// Single-variable tree field: the visible members operate on a buffer named buf_field.
// NOTE(review): the class-head line and the declarations of buf_field and of the constructor
// are not present in this listing — confirm against the real header before editing.
56template<class T>
58
59 public:
// Number of cells covered by the field, i.e. the element count of buf_field.
61 inline u32 get_total_cell_count() { return buf_field.get_size(); }
62
64
67
// Builds a field from a sham::DeviceBuffer<T> constructed with size 0 and dev_sched.
68 static inline KarrasRadixTreeField make_empty(sham::DeviceScheduler_ptr dev_sched) {
69 return KarrasRadixTreeField{sham::DeviceBuffer<T>(0, dev_sched)};
70 }
71};
72
// Multi-variable tree field: the visible members use a buffer named buf_field and a count nvar.
// NOTE(review): the class-head line, the declarations of buf_field / nvar, the constructor and
// the body of make_empty are not present in this listing — confirm against the real header.
73template<class T>
75
76 public:
// Number of cells covered by the field: the element count of buf_field divided by nvar.
78 inline u32 get_total_cell_count() { return buf_field.get_size() / nvar; }
79
82
90
// Factory for an empty field with nvar variables per cell.
// NOTE(review): its return statement is missing here; presumably it mirrors the single-variable
// make_empty with a zero-sized buffer — TODO confirm.
91 static inline KarrasRadixTreeFieldMultiVar make_empty(
92 sham::DeviceScheduler_ptr dev_sched, u32 nvar) {
94 }
95};
96
97namespace shamtree {
98
99 template<class T>
100 KarrasRadixTreeField<T> new_empty_karras_radix_tree_field() {
101 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
102 return KarrasRadixTreeField<T>(sham::DeviceBuffer<T>(0, dev_sched));
103 }
104
105 template<class T>
106 KarrasRadixTreeFieldMultiVar<T> new_empty_karras_radix_tree_field_multi_var(u32 nvar) {
107 auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
108 return KarrasRadixTreeFieldMultiVar<T>::make_empty(dev_sched, nvar);
109 }
110
111 template<class T>
112 KarrasRadixTreeField<T> prepare_karras_radix_tree_field(
113 const KarrasRadixTree &tree, KarrasRadixTreeField<T> &&recycled_tree_field) {
114
115 KarrasRadixTreeField<T> ret = std::forward<KarrasRadixTreeField<T>>(recycled_tree_field);
116
117 ret.buf_field.resize(tree.get_total_cell_count());
118
119 return ret;
120 }
121
122 template<class T>
123 KarrasRadixTreeFieldMultiVar<T> prepare_karras_radix_tree_field_multi_var(
124 const KarrasRadixTree &tree, KarrasRadixTreeFieldMultiVar<T> &&recycled_tree_field) {
125
126 KarrasRadixTreeFieldMultiVar<T> ret = std::move(recycled_tree_field);
127
128 ret.buf_field.resize(tree.get_total_cell_count() * ret.nvar);
129
130 return ret;
131 }
132
133 template<class T, class Fct>
134 void propagate_field_up(
135 KarrasRadixTreeField<T> &tree_field, const KarrasRadixTree &tree, Fct fct_combine) {
136
137 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
138
139 u32 int_cell_count = tree.get_internal_cell_count();
140
141 if (int_cell_count == 0) {
142 return;
143 }
144
145 auto step = [&]() {
146 auto traverser = tree.get_structure_traverser();
147
149 q,
150 sham::MultiRef{traverser},
151 sham::MultiRef{tree_field.buf_field},
152 int_cell_count,
153 [=](u32 gid, auto tree_traverser, T *__restrict tree_field) {
154 u32 left_child = tree_traverser.get_left_child(gid);
155 u32 right_child = tree_traverser.get_right_child(gid);
156
157 T fieldl = tree_field[left_child];
158 T fieldr = tree_field[right_child];
159
160 T field_val = fct_combine(fieldl, fieldr);
161
162 tree_field[gid] = field_val;
163 });
164 };
165
166 for (u32 i = 0; i < tree.tree_depth; i++) {
167 step();
168 }
169 }
170
171 template<class T, class Fct>
172 KarrasRadixTreeField<T> compute_tree_field(
173 const KarrasRadixTree &tree,
174 KarrasRadixTreeField<T> &&recycled_tree_field,
175 const std::function<void(KarrasRadixTreeField<T> &, u32)> &fct_fill_leaf,
176 Fct fct_combine) {
177
178 auto tree_field = prepare_karras_radix_tree_field(
179 tree, std::forward<KarrasRadixTreeField<T>>(recycled_tree_field));
180
181 fct_fill_leaf(tree_field, tree.get_internal_cell_count());
182
183 propagate_field_up(tree_field, tree, std::forward<Fct>(fct_combine));
184
185 return tree_field;
186 }
187
// Builds a tree field from `field` using sham::max both to reduce the objects of each leaf cell
// and to combine the two children of a cell (see the lambda passed to compute_tree_field below).
//   tree                - tree the field is computed for
//   cell_it             - passed to the kernel, where for_each_in_leaf_cell enumerates the
//                         object ids of leaf cell i
//   recycled_tree_field - field handed over to compute_tree_field
//   field               - per-object values, indexed by object id in the kernel
188 template<class T>
189 KarrasRadixTreeField<T> compute_tree_field_max_field(
190 const KarrasRadixTree &tree,
191 const LeafCellIterator &cell_it,
192 KarrasRadixTreeField<T> &&recycled_tree_field,
193 sham::DeviceBuffer<T> &field) {
194
195 sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
196
// Leaf filler handed to compute_tree_field; leaf_offset is the index at which leaf entries start
// in the output buffer.
197 auto fill_leafs = [&](KarrasRadixTreeField<T> &tree_field, u32 leaf_offset) {
// NOTE(review): the head of the call taking the arguments below is not present in this
// listing; the argument list (queue, inputs, outputs, count, functor) runs over
// tree.get_leaf_count() work items.
199 q,
200 sham::MultiRef{field, cell_it},
201 sham::MultiRef{tree_field.buf_field},
202 tree.get_leaf_count(),
203 [leaf_offset](u32 i, const T *field, auto cell_iter, T *comp_field) {
204 // Init with the min value of the type
// NOTE(review): the declaration and initialisation of field_val is not present in this
// listing; it is used below as the running maximum.
206
// Fold every object of leaf cell i into the running maximum.
207 cell_iter.for_each_in_leaf_cell(i, [&](u32 obj_id) {
208 field_val = sham::max(field_val, field[obj_id]);
209 });
210
// Leaf i is stored at index leaf_offset + i of the output buffer.
211 comp_field[leaf_offset + i] = field_val;
212 });
213 };
214
// Internal cells take the max of their two children.
215 return compute_tree_field<T>(
216 tree,
217 std::forward<KarrasRadixTreeField<T>>(recycled_tree_field),
218 fill_leafs,
219 [](T a, T b) {
220 return sham::max(a, b);
221 });
222 }
223
224} // namespace shamtree
Header file describing a Node Instance.
std::uint32_t u32
32 bit unsigned integer
A buffer allocated in USM (Unified Shared Memory)
A SYCL queue associated with a device and a context.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
A data structure representing a field with multiple variables per cell for a Karras Radix Tree.
sham::DeviceBuffer< T > buf_field
field data (size = total_cell_count * nvar)
u32 nvar
number of variables per cell
KarrasRadixTreeFieldMultiVar(sham::DeviceBuffer< T > &&buf_field, u32 nvar)
CTOR.
u32 get_total_cell_count()
Get total cell count.
A data structure representing a Karras Radix Tree Field.
u32 get_total_cell_count()
Get total cell count.
KarrasRadixTreeField(sham::DeviceBuffer< T > &&buf_field)
CTOR.
sham::DeviceBuffer< T > buf_field
field data (size = total_cell_count)
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
STL namespace.
A class that references multiple buffers or similar objects.