Shamrock 2025.10.0
Astrophysical Code
LegacyVtkWriter.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
19#include "shambase/endian.hpp"
20#include "shambase/memory.hpp"
22#include "shambase/time.hpp"
26#include "shambackends/vec.hpp"
27#include "shamcomm/io.hpp"
32#include <fstream>
33#include <sstream>
34#include <string>
35
36namespace shamrock {
37 namespace details {
38
/// Scalar type used to represent one component of T in the VTK file
/// (for a 3-vector of f64 this is f64).
39        template<class T>
40        using repr_t = typename shambase::VectorProperties<T>::component_type;
41
        /// Number of scalar components per element of T (1 for scalars, d for d-vectors).
42        template<class T>
43        static constexpr u32 repr_count = shambase::VectorProperties<T>::dimension;
44
/// Collective write of a sycl::buffer into the VTK file as on-disk type RT.
/// Converts buf (len local elements of T) into a flat buffer of RT scalar
/// components, stages it in USM (device or host depending on device_alloc),
/// then performs a collective MPI write; file_head_ptr is advanced past the
/// written region. Must be called with len > 0 — ranks with no local data use
/// write_buffer_vtktype_no_buf so the collective call still matches.
45        template<class RT, class T>
46        inline void write_buffer_vtktype(
47            MPI_File fh,
48            sycl::buffer<T> &buf,
49            u32 len,
50            u32 sum_len,
51            bool device_alloc,
52            u64 &file_head_ptr) {
53            StackEntry stack_loc{};
54
55            if (len == 0) {
            // NOTE(review): doc extraction dropped original line 56 here — presumably
            // a `shambase::throw_with_loc<std::invalid_argument>(` call; restore from
            // the upstream source.
57                    "Cannot call this function with null buffer length");
58            }
59
            // local / global number of scalar components to write
60            const u32 new_cnt = len * repr_count<T>;
61            const u32 new_cnt_sum = sum_len * repr_count<T>;
62
63            shamlog_debug_mpi_ln("VTK write", new_cnt, new_cnt_sum);
64
65            sycl::queue &q = shamsys::instance::get_compute_queue();
66
            // convert the typed buffer into a flat buffer of RT components
67            sycl::buffer<RT> buf_w = shamrock::details::to_vtk_buf_type<RT>(q, buf, len);
68
69            RT *usm_buf;
70            if (device_alloc) {
71
72                usm_buf = sycl::malloc_device<RT>(new_cnt, q);
73
                // device staging: copy the converted buffer into device USM
74                auto ev = q.submit([&](sycl::handler &cgh) {
75                    sycl::accessor acc_buf{buf_w, cgh, sycl::read_only};
76                    RT *ptr = usm_buf;
77                    cgh.parallel_for(sycl::range<1>{new_cnt}, [=](sycl::item<1> i) {
78                        ptr[i] = (acc_buf[i]);
79                    });
80                });
81                ev.wait(); // TODO wait for the event only when doing MPI calls
82
83            } else {
                // host staging: plain element-by-element copy through a host accessor
84                usm_buf = sycl::malloc_host<RT>(new_cnt, q);
85
86                {
87                    sycl::host_accessor acc_buf{buf_w, sycl::read_only};
88                    for (u32 i = 0; i < new_cnt; i++) {
89                        usm_buf[i] = (acc_buf[i]);
90                    }
91                }
92            }
93
94            shamlog_debug_mpi_ln("VTK write", new_cnt);
95
            // NOTE(review): doc extraction dropped original line 96 here — the call
            // head `shamalgs::collective::viewed_write_all_fetch_known_total_size<RT>(`.
97                fh, usm_buf, new_cnt, new_cnt_sum, file_head_ptr);
98
99            sycl::free(usm_buf, q);
100        }
101
102 template<class RT, class T>
103 inline void write_buffer_vtktype_no_buf(
104 MPI_File fh, u32 sum_len, bool device_alloc, u64 &file_head_ptr) {
105 StackEntry stack_loc{};
106
107 const u32 new_cnt_sum = sum_len * repr_count<T>;
108
109 shamlog_debug_mpi_ln("VTK write", new_cnt_sum);
110
111 sycl::queue &q = shamsys::instance::get_compute_queue();
112
113 shamalgs::collective::viewed_write_all_fetch_known_total_size<RT>(
114 fh, nullptr, 0, new_cnt_sum, file_head_ptr);
115 }
116 } // namespace details
117
/// Supported VTK dataset kinds (only UNSTRUCTURED_GRID is implemented here).
118    enum DataSetTypes { UnstructuredGrid };
119
// NOTE(review): the enclosing `class LegacyVtkWriter {` line (doxygen line 120)
// was lost in doc extraction.
        // MPI file handle; opened in the constructor, closed in the destructor
121        MPI_File mfile{};
        // output path (must end in ".vtk", checked by the constructor)
122        std::string fname;
        // true -> BINARY VTK header, false -> ASCII (see constructor)
123        bool binary;
124
        // current write offset into the file, advanced by every collective write
125        u64 file_head_ptr;
126
        // wall-clock timer started in the constructor, reported in the destructor
127        shambase::Timer timer;
129 private:
/// Append the string s to the VTK file at the current head position and
/// advance file_head_ptr. Presumably collective (every rank must call it) —
/// verify against shamalgs::collective::write_header_raw.
130        inline void head_write(std::string s) {
131            shamalgs::collective::write_header_raw(mfile, s, file_head_ptr);
132        }
133
/// Dispatch a typed buffer write to its on-disk representation: float-based
/// types are written as f32, integer-based (and the remaining fallback branch)
/// as i32. NOTE(review): doc extraction dropped the `if constexpr` /
/// `} else {` lines (doxygen lines 136 and 140); restore from upstream.
134        template<class T>
135        inline void write_buf(sycl::buffer<T> &buf, u32 len, u32 sum_len) {
137                details::write_buffer_vtktype<f32>(mfile, buf, len, sum_len, false, file_head_ptr);
138            } else if constexpr (shambase::VectorProperties<T>::is_int_based) {
139                details::write_buffer_vtktype<i32>(mfile, buf, len, sum_len, false, file_head_ptr);
141                details::write_buffer_vtktype<i32>(mfile, buf, len, sum_len, false, file_head_ptr);
142            }
143        }
144
/// Bufferless twin of write_buf for ranks with no local data: same type
/// dispatch, but enters the collective write with a zero-length contribution.
/// NOTE(review): doc extraction dropped the `if constexpr` / `} else {` lines
/// (doxygen lines 147 and 151); restore from upstream.
145        template<class T>
146        inline void write_buf_no_buf(u32 sum_len) {
148                details::write_buffer_vtktype_no_buf<f32, T>(mfile, sum_len, false, file_head_ptr);
149            } else if constexpr (shambase::VectorProperties<T>::is_int_based) {
150                details::write_buffer_vtktype_no_buf<i32, T>(mfile, sum_len, false, file_head_ptr);
152                details::write_buffer_vtktype_no_buf<i32, T>(mfile, sum_len, false, file_head_ptr);
153            }
154        }
155
/// Map T to the VTK scalar type name used in section headers
/// ("float" / "int" / "unknown"). NOTE(review): the two branch-opening
/// `if constexpr` lines (doxygen lines 158 and 162) were lost in doc
/// extraction; restore from upstream.
156        template<class T>
157        inline std::string get_buf_type_name() {
159                return "float";
160            } else if constexpr (shambase::VectorProperties<T>::is_int_based) {
161                return "int";
163                return "int";
164            } else {
165                return "unknown";
166            }
167        }
168
// Global (all-rank) counts of what has been written so far; consumed by the
// add_*_data_section headers, which refuse to run before the matching write.
169        u64 points_count;
170        bool has_written_points = false;
171
172        u64 cells_count;
173        bool has_written_cells = false;
175 public:
/// Open (and truncate) the VTK file collectively and write the fixed
/// "# vtk DataFile Version 4.2" preamble plus the DATASET line.
/// @param fname  output path, must contain ".vtk"
/// @param binary selects BINARY vs ASCII data encoding in the preamble
/// @param type   dataset kind; only UnstructuredGrid is accepted
/// Throws std::invalid_argument on a bad extension or unknown dataset type.
176        inline LegacyVtkWriter(std::string fname, bool binary, DataSetTypes type)
177            : fname(fname), binary(binary), file_head_ptr(0_u64) {
178
179            StackEntry stack_loc{};
180
181            timer.start();
182
183            shamlog_debug_ln("VtkWriter", "opening :", fname);
184
185            if (fname.find(".vtk") == std::string::npos) {
            // NOTE(review): doc extraction dropped doxygen line 186 here — presumably
            // a `shambase::throw_with_loc<std::invalid_argument>(` call.
187                    "the extension should be .vtk");
188            }
189
190            shamcomm::open_reset_file(mfile, fname);
191
192            std::stringstream ss;
193
194            if (binary) {
195                ss << ("# vtk DataFile Version 4.2\nvtk output\nBINARY\n");
196            } else {
197                ss << ("# vtk DataFile Version 4.2\nvtk output\nASCII\n");
198            }
199
200            if (type == UnstructuredGrid) {
201                ss << ("DATASET UNSTRUCTURED_GRID");
202            } else {
203                throw shambase::make_except_with_loc<std::invalid_argument>("unknown dataset type");
204            }
205
206            std::string write_str = ss.str();
207
208            head_write(write_str);
209        }
210
211 template<class T>
212 void write_points(sycl::buffer<sycl::vec<T, 3>> &buf, u32 len) {
213 StackEntry stack_loc{};
214
215 shamlog_debug_mpi_ln("VTK write", "write_points");
216
217 u32 sum_len = shamalgs::collective::allreduce_sum(len);
218
219 std::stringstream ss;
220 ss << "\n\nPOINTS ";
221 ss << sum_len;
222 ss << " " << get_buf_type_name<sycl::vec<T, 3>>();
223 ss << "\n";
224
225 head_write(ss.str());
226
227 write_buf(buf, len, sum_len);
228
229 has_written_points = true;
230 points_count = sum_len;
231 }
232
233 template<class T>
234 void write_points_no_buf() {
235 StackEntry stack_loc{};
236
237 shamlog_debug_mpi_ln("VTK write", "write_points no buf");
238
239 u32 sum_len = shamalgs::collective::allreduce_sum(0);
240
241 std::stringstream ss;
242 ss << "\n\nPOINTS ";
243 ss << sum_len;
244 ss << " " << get_buf_type_name<sycl::vec<T, 3>>();
245 ss << "\n";
246
247 head_write(ss.str());
248
249 write_buf_no_buf<T>(sum_len);
250
251 has_written_points = true;
252 points_count = sum_len;
253 }
254
255 template<class T>
256 void write_points(std::unique_ptr<sycl::buffer<sycl::vec<T, 3>>> &buf, u32 len) {
257 if (len > 0) {
258 write_points(shambase::get_check_ref(buf), len);
259 } else {
260 write_points_no_buf<T>();
261 }
262 }
263
/// Write len axis-aligned boxes (given by min/max corner buffers) as
/// VTK_VOXEL cells: expands each box to its 8 corner points, writes them as
/// a POINTS section, then emits the CELLS connectivity (indices are global,
/// shifted by this rank's point offset) and the CELL_TYPES section.
/// Collective: every rank must call with its local len; the sequence of
/// head_write / write_points / write_buf calls is order-critical.
264        template<class T>
265        void write_voxel_cells(
266            sycl::buffer<sycl::vec<T, 3>> &buf_min,
267            sycl::buffer<sycl::vec<T, 3>> &buf_max,
268            u32 len) {
269
270            sycl::buffer<sycl::vec<T, 3>> pos_points(len * 8);
271
            // NOTE(review): `total_byte_count` is used here as a global *element*
            // count (and head_offset as an element offset) — verify fetch_view's
            // contract; the field names suggest bytes.
272            auto view = shamalgs::collective::fetch_view(len);
273            u32 sum_len = view.total_byte_count;
274            u32 len_offset = view.head_offset;
275
            // kernel 1: expand each (min,max) box into its 8 corners
276            shamsys::instance::get_compute_queue().submit([&](sycl::handler &cgh) {
277                sycl::accessor acc_min{buf_min, cgh, sycl::read_only};
278                sycl::accessor acc_max{buf_max, cgh, sycl::read_only};
279
280                sycl::accessor acc_points{pos_points, cgh, sycl::write_only, sycl::no_init};
281
282                cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> id) {
283                    u32 idx = id.get_linear_id() * 8;
284
285                    sycl::vec<T, 3> pmin = acc_min[id];
286                    sycl::vec<T, 3> pmax = acc_max[id];
287
                    // VTK_VOXEL corner ordering: x varies fastest, then y, then z
288                    acc_points[idx + 0] = pmin;
289                    acc_points[idx + 1] = {pmax.x(), pmin.y(), pmin.z()};
290                    acc_points[idx + 2] = {pmin.x(), pmax.y(), pmin.z()};
291                    acc_points[idx + 3] = {pmax.x(), pmax.y(), pmin.z()};
292                    acc_points[idx + 4] = {pmin.x(), pmin.y(), pmax.z()};
293                    acc_points[idx + 5] = {pmax.x(), pmin.y(), pmax.z()};
294                    acc_points[idx + 6] = {pmin.x(), pmax.y(), pmax.z()};
295                    acc_points[idx + 7] = pmax;
296                });
297            });
298
299            write_points(pos_points, len * 8);
300
            // "CELLS <ncells> <total ints>" header: 9 ints per voxel (count + 8 ids)
301            std::stringstream ss;
302            ss << "\n\nCELLS ";
303            ss << sum_len;
304            ss << " " << sum_len * 9;
305            ss << "\n";
306            head_write(ss.str());
307
308            sycl::buffer<i32> idx_cells(len * 9);
309            sycl::buffer<i32> type_cell(len);
310
            // kernel 2: build connectivity (global point ids) and cell types
311            shamsys::instance::get_compute_queue().submit([&](sycl::handler &cgh) {
312                sycl::accessor idxs{idx_cells, cgh, sycl::write_only, sycl::no_init};
313                sycl::accessor cellt{type_cell, cgh, sycl::write_only, sycl::no_init};
314
                // offset of this rank's first point in the global POINTS array
315                u32 idp_off = len_offset * 8;
316
317                cgh.parallel_for(sycl::range<1>{len}, [=](sycl::item<1> item) {
318                    u32 idp = item.get_linear_id() * 8;
319                    u32 idx = item.get_linear_id() * 9;
320
                    // first int is the per-cell point count (8), then the 8 ids
321                    idxs[idx + 0] = 8;
322                    idxs[idx + 1] = idp_off + idp + 0;
323                    idxs[idx + 2] = idp_off + idp + 1;
324                    idxs[idx + 3] = idp_off + idp + 2;
325                    idxs[idx + 4] = idp_off + idp + 3;
326                    idxs[idx + 5] = idp_off + idp + 4;
327                    idxs[idx + 6] = idp_off + idp + 5;
328                    idxs[idx + 7] = idp_off + idp + 6;
329                    idxs[idx + 8] = idp_off + idp + 7;
330
                    // 11 == VTK_VOXEL in the legacy VTK cell-type table
331                    cellt[item] = 11;
332                });
333            });
334
335            write_buf(idx_cells, len * 9, sum_len * 9);
336
337            std::stringstream ss2;
338            ss2 << "\n\nCELL_TYPES ";
339            ss2 << sum_len;
340            ss2 << "\n";
341            head_write(ss2.str());
342
343            write_buf(type_cell, len, sum_len);
344
345            cells_count = sum_len;
346            has_written_cells = true;
347        }
348
/// Emit the "POINT_DATA <n>" section header; must follow a write_points call.
/// NOTE(review): doc extraction dropped doxygen line 352 (presumably the
/// `shambase::throw_with_loc<...>(` call for the message below).
349        void add_point_data_section() {
350
351            if (!has_written_points) {
353                    "no points had been written");
354            }
355
356            std::stringstream ss;
357            ss << "\n\nPOINT_DATA ";
358            ss << points_count;
359
360            head_write(ss.str());
361        }
362
/// Emit the "CELL_DATA <n>" section header; must follow write_voxel_cells.
/// NOTE(review): doc extraction dropped doxygen line 366 (presumably the
/// `shambase::throw_with_loc<...>(` call for the message below).
363        void add_cell_data_section() {
364
365            if (!has_written_cells) {
367                    "no cells had been written");
368            }
369
370            std::stringstream ss;
371            ss << "\n\nCELL_DATA ";
372            ss << cells_count;
373
374            head_write(ss.str());
375        }
376
/// Emit the "FIELD FieldData <num_field>" header announcing num_field
/// subsequent write_field calls; requires points to have been written first.
/// NOTE(review): doc extraction dropped doxygen line 380 (presumably the
/// `shambase::throw_with_loc<...>(` call for the message below).
377        void add_field_data_section(u32 num_field) {
378
379            if (!has_written_points) {
381                    "no points had been written");
382            }
383
384            std::stringstream ss;
385            ss << "\nFIELD FieldData ";
386            ss << num_field;
387
388            head_write(ss.str());
389        }
390
391 template<class T>
392 void write_field(std::string name, sycl::buffer<T> &buf, u32 len) {
393
394 u32 sum_len = shamalgs::collective::allreduce_sum(len);
395
396 std::stringstream ss;
397 ss << "\n" << name;
398 ss << " " << details::repr_count<T>;
399 ss << " " << sum_len;
400 ss << " " << get_buf_type_name<T>();
401 ss << "\n";
402 head_write(ss.str());
403
404 write_buf(buf, len, sum_len);
405 }
406
407 template<class T>
408 void write_field_no_buf(std::string name) {
409
410 u32 sum_len = shamalgs::collective::allreduce_sum(0);
411
412 std::stringstream ss;
413 ss << "\n" << name;
414 ss << " " << details::repr_count<T>;
415 ss << " " << sum_len;
416 ss << " " << get_buf_type_name<T>();
417 ss << "\n";
418 head_write(ss.str());
419
420 write_buf_no_buf<T>(sum_len);
421 }
422
/// write_field overload taking an optional buffer: ranks with len == 0 fall
/// back to the bufferless collective path so the MPI call still matches.
/// Throws if the provided buffer is smaller than the announced count.
/// NOTE(review): doc extraction dropped doxygen line 428 (presumably
/// `shambase::throw_with_loc<std::invalid_argument>(shambase::format(`).
423        template<class T>
424        void write_field(std::string name, std::unique_ptr<sycl::buffer<T>> &buf, u32 len) {
425            if (len > 0) {
426                sycl::buffer<T> &buf_ref = shambase::get_check_ref(buf);
427                if (buf_ref.size() < len) {
429                        "the buffer is smaller than expected write field size\n buf size = {}, "
430                        "cnt = {}",
431                        buf_ref.size(),
432                        len));
433                }
434                write_field(name, buf_ref, len);
435            } else {
436                write_field_no_buf<T>(name);
437            }
438        }
439
/// Close the MPI file (collective), stop the timer, and on world rank 0 log
/// the dump destination, elapsed time and effective bandwidth
/// (file_head_ptr bytes / elapsed seconds). The close-before-timing order is
/// deliberate: the elapsed time includes the collective close.
440        inline ~LegacyVtkWriter() {
441            shamlog_debug_mpi_ln("LegacyVtkWriter", "calling : shamcomm::mpi::File_close");
442            shamcomm::mpi::File_close(&mfile);
443            timer.end();
444
445            if (shamcomm::world_rank() == 0) {
446                logger::info_ln(
447                    "VTK Dump",
448                    shambase::format(
449                        "dump to {}\n - took {}, bandwidth = {}/s",
450                        fname,
451                        timer.get_time_str(),
452                        shambase::readable_sizeof(file_head_ptr / timer.elasped_sec())));
453            }
454        }
455
// Non-copyable: the writer owns a live MPI file handle.
456        LegacyVtkWriter(const LegacyVtkWriter &) = delete;
457        LegacyVtkWriter &operator=(const LegacyVtkWriter &) = delete;
        // NOTE(review): the move-constructor signature line (doxygen line 458) was
        // lost in doc extraction. As visible below, the move ctor copies
        // other.mfile without resetting it, so BOTH objects' destructors would
        // File_close the same handle — potential double close; verify upstream.
459            : mfile(other.mfile), fname(std::move(other.fname)), binary(other.binary),
460              file_head_ptr(other.file_head_ptr) {} // move constructor
461        LegacyVtkWriter &operator=(LegacyVtkWriter &&other) = delete; // move assignment
462 };
463} // namespace shamrock
This header does the MPI include and wrap MPI calls.
Header file describing a Node Instance.
sycl::queue & get_compute_queue(u32 id=0)
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
Class Timer measures the time elapsed since the timer was started.
Definition time.hpp:96
std::string get_time_str() const
Converts the stored nanosecond time to a string representation.
Definition time.hpp:117
void end()
Stops the timer and stores the elapsed time in nanoseconds.
Definition time.hpp:111
f64 elasped_sec() const
Converts the stored nanosecond time to a floating point representation in seconds.
Definition time.hpp:123
void start()
Starts the timer.
Definition time.hpp:106
Namespace for internal details of the logs module.
std::string readable_sizeof(double size)
given a sizeof value return a readable string Example : readable_sizeof(1024*1024*1024) -> "1....
Definition string.hpp:139
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error if the pointer does not hold an object.
Definition memory.hpp:110
void open_reset_file(MPI_File &fh, const std::string &fname)
Open a MPI file and remove its content.
Definition io.cpp:24
i32 world_rank()
Gives the rank of the current process in the MPI communicator.
Definition worldInfo.cpp:40
namespace for the main framework
Definition __init__.py:1
void viewed_write_all_fetch_known_total_size(MPI_File fh, T *ptr_data, u64 data_cnt, u64 total_cnt, u64 &file_head_ptr)
Writes data to an MPI file in a collective manner and updates the file head pointer.
Definition io.hpp:68
This file contains the definition for the stacktrace related functionality.