Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
ReattributeDataUtility.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
18#include "shambase/string.hpp"
20#include "shamalgs/memory.hpp"
27#include <vector>
28
29namespace shamrock {
30
39 PatchScheduler &sched;
40
41 public:
47 ReattributeDataUtility(PatchScheduler &sched) : sched(sched) {}
48
61 template<class T>
63 SerialPatchTree<T> &sptree, u32 ipos) {
64
65 StackEntry stack_loc{};
66
68
69 sched.patch_data.for_each_patchdata([&](u64 id, shamrock::patch::PatchDataLayer &pdat) {
70 if (!pdat.is_empty()) {
71
72 PatchDataField<T> &pos_field = pdat.get_field<T>(ipos);
73
74 if (pos_field.get_nvar() != 1) {
75 shambase::throw_unimplemented();
76 }
77
78 newid_buf_map.add_obj(
79 id,
80 sptree.compute_patch_owner(
81 shamsys::instance::get_compute_scheduler_ptr(),
82 pos_field.get_buf(),
83 pos_field.get_obj_cnt()));
84
85 bool err_id_in_newid = false;
86 {
87 sycl::host_accessor nid{newid_buf_map.get(id), sycl::read_only};
88 for (u32 i = 0; i < pdat.get_obj_cnt(); i++) {
89 bool err = nid[i] == u64_max;
90 err_id_in_newid = err_id_in_newid || (err);
91 }
92 }
93
94 if (err_id_in_newid) {
96 "a new id could not be computed");
97 }
98 }
99 });
100
101 return newid_buf_map;
102 }
103
117 shambase::DistributedData<sycl::buffer<u64>> new_pid) {
119
120 StackEntry stack_loc{};
121
122 using namespace shamrock::patch;
123
124 std::unordered_map<u64, u64> histogram_extract;
125
126 sched.patch_data.for_each_patchdata(
127 [&](u64 current_pid, shamrock::patch::PatchDataLayer &pdat) {
128 histogram_extract[current_pid] = 0;
129 if (!pdat.is_empty()) {
130
131 sycl::host_accessor nid{new_pid.get(current_pid), sycl::read_only};
132
133 if (false) {
134
135 const u32 cnt = pdat.get_obj_cnt();
136
137 for (u32 i = cnt - 1; i < cnt; i--) {
138 u64 new_pid = nid[i];
139 if (current_pid != new_pid) {
140
141 if (!part_exchange.has_key(current_pid, new_pid)) {
142 part_exchange.add_obj(
143 current_pid,
144 new_pid,
145 PatchDataLayer(sched.get_layout_ptr_old()));
146 }
147
148 part_exchange.for_each(
149 [&](u64 _old_id, u64 _new_id, PatchDataLayer &pdat_int) {
150 if (_old_id == current_pid && _new_id == new_pid) {
151 pdat.extract_element(i, pdat_int);
152 histogram_extract[current_pid]++;
153 }
154 });
155 }
156 }
157 } else {
158 std::vector<u32> keep_ids;
159 std::unordered_map<u64, std::vector<u32>> extract_indexes;
160
161 const u32 cnt = pdat.get_obj_cnt();
162 for (u32 i = 0; i < cnt; i++) {
163 u64 new_pid = nid[i];
164 if (current_pid != new_pid) {
165 extract_indexes[new_pid].push_back(i);
166 histogram_extract[current_pid]++;
167 } else {
168 keep_ids.push_back(i);
169 }
170 }
171
172 for (auto &[new_id, vec] : extract_indexes) {
173
174 u64 new_pid = new_id;
175 std::vector<u32> &idx_extract = vec;
176
177 if (!part_exchange.has_key(current_pid, new_pid)) {
178 part_exchange.add_obj(
179 current_pid,
180 new_pid,
181 PatchDataLayer(sched.get_layout_ptr_old()));
182 }
183
184 part_exchange.for_each(
185 [&](u64 _old_id, u64 _new_id, PatchDataLayer &pdat_int) {
186 if (_old_id == current_pid && _new_id == new_pid) {
187 pdat.append_subset_to(idx_extract, pdat_int);
188 }
189 });
190 }
191
192 sycl::buffer<u32> keep_idx = shamalgs::memory::vec_to_buf(keep_ids);
193 pdat.keep_ids(keep_idx, keep_ids.size());
194 }
195 }
196 });
197
198 for (auto &[k, v] : histogram_extract) {
199 shamlog_debug_ln("ReattributeDataUtility", "patch", k, "extract=", v);
200 }
201
202 return part_exchange;
203 }
204
215 template<class T>
217 SerialPatchTree<T> &sptree, std::string position_field) {
218 StackEntry stack_loc{};
219
220 using namespace shambase;
221 using namespace shamrock::patch;
222
223 u32 ipos = sched.pdl_old().get_field_idx<T>(position_field);
224
225 DistributedData<sycl::buffer<u64>> new_pid = compute_new_pid(sptree, ipos);
226
227 DistributedDataShared<patch::PatchDataLayer> part_exchange = extract_elements(new_pid);
228
229 part_exchange.for_each([](u64 sender, u64 receiver, PatchDataLayer &pdat) {
230 shamlog_debug_ln("ReattributeDataUtility", sender, receiver, pdat.get_obj_cnt());
231 });
232
234
236
237 shamalgs::collective::serialize_sparse_comm<PatchDataLayer>(
238 shamsys::instance::get_compute_scheduler_ptr(),
239 std::move(part_exchange),
240 recv_dat,
241 [&](u64 id) {
242 return sched.get_patch_rank_owner(id);
243 },
244 [](PatchDataLayer &pdat) {
245 shamalgs::SerializeHelper ser(shamsys::instance::get_compute_scheduler_ptr());
246 ser.allocate(pdat.serialize_buf_byte_size());
247 pdat.serialize_buf(ser);
248 return ser.finalize();
249 },
250 [&](sham::DeviceBuffer<u8> &&buf) {
251 // exchange the buffer held by the distrib data and give it to the serializer
253 shamsys::instance::get_compute_scheduler_ptr(),
254 std::forward<sham::DeviceBuffer<u8>>(buf));
255 return PatchDataLayer::deserialize_buf(ser, sched.get_layout_ptr_old());
256 },
257 cache);
258
259 recv_dat.for_each([&](u64 sender, u64 receiver, PatchDataLayer &pdat) {
260 shamlog_debug_ln("Part Exchanges", format("send = {} recv = {}", sender, receiver));
261 sched.patch_data.get_pdat(receiver).insert_elements(pdat);
262 });
263 }
264 };
265
266} // namespace shamrock
Header file describing a Node Instance.
MPI scheduler.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
The MPI scheduler.
SchedulerPatchData patch_data
handle the data of the patches of the scheduler
A buffer allocated in USM (Unified Shared Memory)
Container for objects shared between two distributed data elements.
bool has_key(u64 left_id, u64 right_id) const
Check if a patch pair exists in the container.
void for_each(std::function< void(u64, u64, T &)> &&f)
Apply a function to all stored objects.
iterator add_obj(u64 left_id, u64 right_id, T &&obj)
Add an object associated with a patch pair.
Represents a collection of objects distributed across patches identified by a u64 id.
iterator add_obj(u64 id, T &&obj)
Adds a new object to the collection.
T & get(u64 id)
Returns a reference to an object in the collection.
Utility class used to move the objects between patches.
ReattributeDataUtility(PatchScheduler &sched)
Constructor.
shambase::DistributedDataShared< shamrock::patch::PatchDataLayer > extract_elements(shambase::DistributedData< sycl::buffer< u64 > > new_pid)
Extracts, from each patch, the elements whose new patch ID differs from their current patch ID.
shambase::DistributedData< sycl::buffer< u64 > > compute_new_pid(SerialPatchTree< T > &sptree, u32 ipos)
Computes the new patch owner IDs for the objects in the patches based on their position in space.
void reatribute_patch_objects(SerialPatchTree< T > &sptree, std::string position_field)
Reattribute objects based on a given position field.
PatchDataLayer container class, the layout is described in patchdata_layout.
void extract_element(u32 pidx, PatchDataLayer &out_pdat)
Extract the particle at index pidx and insert it into the provided PatchDataLayer (out_pdat).
sycl::buffer< T > vec_to_buf(const std::vector< T > &buf)
Convert a std::vector to a sycl::buffer
Definition memory.cpp:29
namespace for basic c++ utilities
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
namespace for the main framework
Definition __init__.py:1
constexpr u64 u64_max
Maximum value of a u64
main include file for memory algorithms